From 14fccb84c9a4a9ddcf1ec6eb689e6bdf0394b07b Mon Sep 17 00:00:00 2001 From: Cloaked9000 Date: Thu, 15 Dec 2016 12:29:23 +0000 Subject: [PATCH] More work on SSL support. You can now accept SSL connections using SSLListeners, and then send/receive data through the associated SSLSocket. HttpSocket's now support both HTTP and HTTPS, using templates: fr::HttpSocket https_socket; fr::HttpSocket http_socket; --- CMakeLists.txt | 2 +- include/HttpSocket.h | 27 +++++++++++-- include/SSLListener.h | 3 +- include/SSLSocket.h | 16 +++----- main.cpp | 43 ++++++++++++-------- src/HttpRequest.cpp | 1 - src/HttpSocket.cpp | 23 ----------- {include => src}/SSLListener.cpp | 37 ++++++++++-------- {include => src}/SSLSocket.cpp | 67 ++++++++++++++++++-------------- 9 files changed, 116 insertions(+), 103 deletions(-) rename {include => src}/SSLListener.cpp (81%) rename {include => src}/SSLSocket.cpp (70%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f4cc37c..27be970 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR}) include_directories(include) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread -lmbedtls -lmbedx509 -lmbedcrypto") -set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h include/SSLSocket.cpp include/SSLSocket.h include/SSLListener.cpp include/SSLListener.h) +set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h src/SSLSocket.cpp include/SSLSocket.h src/SSLListener.cpp include/SSLListener.h) add_executable(frnetlib ${SOURCE_FILES}) TARGET_LINK_LIBRARIES(frnetlib ${MBEDTLS_LIBRARIES} -lmbedtls -lmbedx509 -lmbedcrypto) \ No newline at end of file diff --git a/include/HttpSocket.h b/include/HttpSocket.h index 342268a..be5ea4d 100644 --- a/include/HttpSocket.h +++ b/include/HttpSocket.h @@ -10,7 +10,8 @@ namespace fr { - class HttpSocket : public TcpSocket + template + class HttpSocket : public SocketType { public: /*! @@ -19,7 +20,23 @@ namespace fr * @param request Where to store the received request. * @return The status of the operation. */ - Socket::Status receive(Http &request); + Socket::Status receive(Http &request) + { + //Create buffer to receive_request the request + std::string buffer(2048, '\0'); + + //Receive the request + size_t received; + Socket::Status status = SocketType::receive_raw(&buffer[0], buffer.size(), received); + if(status != Socket::Success) + return status; + buffer.resize(received); + + //Parse it + request.parse(buffer); + + return Socket::Success; + } /*! * Sends a HTTP request to the connected socket. @@ -27,7 +44,11 @@ namespace fr * @param request The HTTP request to send. * @return The status of the operation. */ - Socket::Status send(const Http &request); + Socket::Status send(const Http &request) + { + std::string data = request.construct(SocketType::remote_address); + return SocketType::send_raw(&data[0], data.size()); + } }; } diff --git a/include/SSLListener.h b/include/SSLListener.h index c1666f7..08e4a92 100644 --- a/include/SSLListener.h +++ b/include/SSLListener.h @@ -5,7 +5,7 @@ #ifndef FRNETLIB_SSLLISTENER_H #define FRNETLIB_SSLLISTENER_H -//#define SSL_ENABLED +#define SSL_ENABLED #ifdef SSL_ENABLED @@ -50,7 +50,6 @@ namespace fr mbedtls_net_context listen_fd; mbedtls_entropy_context entropy; mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_ssl_context ssl; mbedtls_ssl_config conf; mbedtls_x509_crt srvcert; mbedtls_pk_context pkey; diff --git a/include/SSLSocket.h b/include/SSLSocket.h index e93ad27..efea3a3 100644 --- a/include/SSLSocket.h +++ b/include/SSLSocket.h @@ -5,7 +5,7 @@ #ifndef FRNETLIB_SSL_SOCKET_H #define FRNETLIB_SSL_SOCKET_H -//#define SSL_ENABLED +#define SSL_ENABLED #ifdef SSL_ENABLED @@ -94,13 +94,6 @@ namespace fr */ Status receive_raw(void *data, size_t data_size, size_t &received) override; - /*! - * Sets the socket file descriptor. - * - * @param descriptor The socket descriptor. - */ - void set_descriptor(int descriptor) override; - /*! * Close the connection. */ @@ -115,11 +108,14 @@ namespace fr */ Socket::Status connect(const std::string &address, const std::string &port) override; + void set_ssl_context(std::unique_ptr context); + void set_net_context(std::unique_ptr context); + private: - mbedtls_net_context ssl_socket_descriptor; + std::unique_ptr ssl_socket_descriptor; mbedtls_entropy_context entropy; mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_ssl_context ssl; + std::unique_ptr ssl; mbedtls_ssl_config conf; mbedtls_x509_crt cacert; uint32_t flags; diff --git a/main.cpp b/main.cpp index ef3da8d..26a28e6 100644 --- a/main.cpp +++ b/main.cpp @@ -18,24 +18,35 @@ int main() // return 1; // } // -// fr::SSLSocket socket; -// if(listener.accept(socket) != fr::Socket::Success) +// while(true) // { -// std::cout << "Failed to accept client" << std::endl; -// return 2; +// fr::HttpSocket http_socket; +// if(listener.accept(http_socket) != fr::Socket::Success) +// { +// std::cout << "Failed to accept client" << std::endl; +// continue; +// } +// +// fr::HttpRequest request; +// if(http_socket.receive(request) != fr::Socket::Success) +// { +// std::cout << "Failed to receive data" << std::endl; +// continue; +// } +// else +// { +// std::cout << "Read successfully" << std::endl; +// } +// +// std::cout << "Got: " << request.get_body() << std::endl; +// +// fr::HttpResponse response; +// response.set_body("

Hello, SSL World!

"); +// http_socket.send(response); +// http_socket.close(); // } -// -// std::string buf(1024, '\0'); -// size_t received = 0; -// if(socket.receive_raw(&buf[0], buf.size(), received) != fr::Socket::Success) -// { -// std::cout << "Failed to receive data" << std::endl; -// return 3; -// } -// -// std::cout << "Got: " << buf.substr(0, received) << std::endl; -// -// + + // fr::SSLSocket socket; // if(socket.connect("lloydsenpai.xyz", "443") != fr::Socket::Success) // return 1; diff --git a/src/HttpRequest.cpp b/src/HttpRequest.cpp index 21fd60b..228c3f5 100644 --- a/src/HttpRequest.cpp +++ b/src/HttpRequest.cpp @@ -129,7 +129,6 @@ namespace fr //Add in the body request += body + "\n"; - std::cout << "constructed: " << std::endl << request << std::endl; return request; } } \ No newline at end of file diff --git a/src/HttpSocket.cpp b/src/HttpSocket.cpp index 8a23993..62375dd 100644 --- a/src/HttpSocket.cpp +++ b/src/HttpSocket.cpp @@ -8,27 +8,4 @@ namespace fr { - Socket::Status HttpSocket::receive(Http &request) - { - //Create buffer to receive_request the request - std::string buffer(2048, '\0'); - - //Receive the request - size_t received; - Socket::Status status = receive_raw(&buffer[0], buffer.size(), received); - if(status != Socket::Success) - return status; - buffer.resize(received); - - //Parse it - request.parse(buffer); - - return Socket::Success; - } - - Socket::Status HttpSocket::send(const Http &request) - { - std::string data = request.construct(remote_address); - return send_raw(&data[0], data.size()); - } } \ No newline at end of file diff --git a/include/SSLListener.cpp b/src/SSLListener.cpp similarity index 81% rename from include/SSLListener.cpp rename to src/SSLListener.cpp index 8aae3de..dcf6afe 100644 --- a/include/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -3,6 +3,7 @@ // #include "SSLListener.h" +#ifdef SSL_ENABLED namespace fr { @@ -10,7 +11,6 @@ namespace fr { //Initialise SSL objects required mbedtls_net_init(&listen_fd); - mbedtls_ssl_init(&ssl); mbedtls_ssl_config_init(&conf); mbedtls_x509_crt_init(&srvcert); mbedtls_pk_init(&pkey); @@ -65,12 +65,6 @@ namespace fr return; } - if((error = mbedtls_ssl_setup( &ssl, &conf ) ) != 0) - { - std::cout << "Failed to apply SSL setings: " << error << std::endl; - return; - } - } SSLListener::~SSLListener() @@ -78,7 +72,6 @@ namespace fr mbedtls_net_free(&listen_fd); mbedtls_x509_crt_free(&srvcert); mbedtls_pk_free(&pkey); - mbedtls_ssl_free(&ssl); mbedtls_ssl_config_free(&conf); mbedtls_ctr_drbg_free(&ctr_drbg); mbedtls_entropy_free( &entropy); @@ -87,7 +80,7 @@ namespace fr Socket::Status fr::SSLListener::listen(const std::string &port) { //Bind to port - if(mbedtls_net_bind( &listen_fd, NULL, port.c_str(), MBEDTLS_NET_PROTO_TCP) != 0) + if(mbedtls_net_bind(&listen_fd, NULL, port.c_str(), MBEDTLS_NET_PROTO_TCP) != 0) { return Socket::BindFailed; } @@ -97,32 +90,42 @@ namespace fr Socket::Status SSLListener::accept(SSLSocket &client) { int error = 0; + std::unique_ptr ssl(new mbedtls_ssl_context); + mbedtls_ssl_init(ssl.get()); + if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) + { + std::cout << "Failed to apply SSL setings: " << error << std::endl; + return Socket::Error; + } + //Accept a connection - mbedtls_net_context client_fd; - mbedtls_net_init(&client_fd); + std::unique_ptr client_fd(new mbedtls_net_context); + mbedtls_net_init(client_fd.get()); - if((error = mbedtls_net_accept(&listen_fd, &client_fd, NULL, 0, NULL)) != 0) + if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0) { std::cout << "Accept error: " << error << std::endl; return Socket::Error; } - mbedtls_ssl_set_bio( &ssl, &client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL); //SSL Handshake - while((error = mbedtls_ssl_handshake(&ssl)) != 0) + while((error = mbedtls_ssl_handshake(ssl.get())) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { - std::cout << "Handshake failed: " << error << std::end + std::cout << "Handshake failed: " << error << std::endl; return Socket::Status::HandshakeFailed; } } //Set socket details - client.set_descriptor(client_fd.fd); + client.set_net_context(std::move(client_fd)); + client.set_ssl_context(std::move(ssl)); return Socket::Success; } -} \ No newline at end of file +} +#endif //SSL_ENABLED \ No newline at end of file diff --git a/include/SSLSocket.cpp b/src/SSLSocket.cpp similarity index 70% rename from include/SSLSocket.cpp rename to src/SSLSocket.cpp index 52c8e69..115ccc0 100644 --- a/include/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -3,6 +3,7 @@ // #include "SSLSocket.h" +#include #ifdef SSL_ENABLED namespace fr @@ -13,8 +14,6 @@ namespace fr const char *pers = "ssl_client1"; //Initialise mbedtls structures - mbedtls_net_init(&ssl_socket_descriptor); - mbedtls_ssl_init(&ssl); mbedtls_ssl_config_init(&conf); mbedtls_x509_crt_init(&cacert); mbedtls_ctr_drbg_init(&ctr_drbg); @@ -41,9 +40,7 @@ namespace fr close(); //Cleanup mbedsql stuff - mbedtls_net_free(&ssl_socket_descriptor); mbedtls_x509_crt_free(&cacert); - mbedtls_ssl_free(&ssl); mbedtls_ssl_config_free(&conf); mbedtls_ctr_drbg_free(&ctr_drbg); mbedtls_entropy_free(&entropy); @@ -53,7 +50,10 @@ namespace fr { if(is_connected) { - mbedtls_ssl_close_notify(&ssl); + if(ssl) + mbedtls_ssl_close_notify(ssl.get()); + if(ssl_socket_descriptor) + mbedtls_net_free(ssl_socket_descriptor.get()); is_connected = false; } } @@ -61,7 +61,7 @@ namespace fr Socket::Status SSLSocket::send_raw(const char *data, size_t size) { int error = 0; - while((error = mbedtls_ssl_write(&ssl, (const unsigned char *)data, size)) <= 0) + while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { @@ -74,26 +74,23 @@ namespace fr Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) { - int read = 0; + int read = MBEDTLS_ERR_SSL_WANT_READ; received = 0; if(unprocessed_buffer.size() < data_size) { - read = mbedtls_ssl_read(&ssl, (unsigned char *)recv_buffer.get(), RECV_CHUNK_SIZE); - - if(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) + while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) { - received = 0; - return Socket::Status::Success; + read = mbedtls_ssl_read(ssl.get(), (unsigned char *)recv_buffer.get(), RECV_CHUNK_SIZE); } - else if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) + + if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { - std::cout << "disconnected" << std::endl; return Socket::Status::Disconnected; } else if(read <= 0) { - std::cout << "read <= 0" << std::endl; - return Socket::Status::Error; + //No data. But no error occurred. + return Socket::Status::Success; } received += read; @@ -114,18 +111,17 @@ namespace fr } - void SSLSocket::set_descriptor(int descriptor) - { - is_connected = true; - socket_descriptor = descriptor; - ssl_socket_descriptor.fd = descriptor; - } - Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) { - //Initialise the connection using mbedtls + //Initialise mbedtls stuff + ssl = std::unique_ptr(new mbedtls_ssl_context); + ssl_socket_descriptor = std::unique_ptr(new mbedtls_net_context); + mbedtls_ssl_init(ssl.get()); + mbedtls_net_init(ssl_socket_descriptor.get()); + + //Initialise the connection using mbedtlsl int error = 0; - if((error = mbedtls_net_connect(&ssl_socket_descriptor, address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0) + if((error = mbedtls_net_connect(ssl_socket_descriptor.get(), address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0) { return Socket::Status::ConnectionFailed; } @@ -140,20 +136,20 @@ namespace fr mbedtls_ssl_conf_ca_chain(&conf, &cacert, NULL); mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg); - if((error = mbedtls_ssl_setup(&ssl, &conf)) != 0) + if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0) { return Socket::Status::Error; } - if((error = mbedtls_ssl_set_hostname(&ssl, address.c_str())) != 0) + if((error = mbedtls_ssl_set_hostname(ssl.get(), address.c_str())) != 0) { return Socket::Status::Error; } - mbedtls_ssl_set_bio(&ssl, &ssl_socket_descriptor, mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, NULL); //Do SSL handshake - while((error = mbedtls_ssl_handshake(&ssl)) != 0) + while((error = mbedtls_ssl_handshake(ssl.get())) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { @@ -163,7 +159,7 @@ namespace fr } //Verify server certificate - if((flags = mbedtls_ssl_get_verify_result(&ssl)) != 0) + if((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0) { char verify_buffer[512]; mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags ); @@ -177,6 +173,17 @@ namespace fr remote_address = address + ":" + port; return Socket::Status::Success; } + + void SSLSocket::set_ssl_context(std::unique_ptr context) + { + ssl = std::move(context); + } + + void SSLSocket::set_net_context(std::unique_ptr context) + { + is_connected = true; + ssl_socket_descriptor = std::move(context); + } } #endif \ No newline at end of file