From 69d183ed183a5d1f3e03af238e1c688a381c068e Mon Sep 17 00:00:00 2001 From: Cloaked9000 Date: Thu, 15 Dec 2016 14:57:01 +0000 Subject: [PATCH] Code refactoring --- include/HttpSocket.h | 2 +- include/NetworkEncoding.h | 17 +++++ include/SSLListener.h | 16 ++++- include/SSLSocket.h | 47 ++++++++++++-- include/Socket.h | 131 +++++++++++++++++++++----------------- include/TcpListener.h | 4 +- include/TcpSocket.h | 69 ++++++++------------ main.cpp | 68 ++++++++++---------- src/SSLListener.cpp | 10 +-- src/SSLSocket.cpp | 5 +- src/Socket.cpp | 63 ++++++++++++++++++ src/TcpSocket.cpp | 66 ++++--------------- 12 files changed, 294 insertions(+), 204 deletions(-) diff --git a/include/HttpSocket.h b/include/HttpSocket.h index be5ea4d..b2c03be 100644 --- a/include/HttpSocket.h +++ b/include/HttpSocket.h @@ -46,7 +46,7 @@ namespace fr */ Socket::Status send(const Http &request) { - std::string data = request.construct(SocketType::remote_address); + std::string data = request.construct(SocketType::get_remote_address()); return SocketType::send_raw(&data[0], data.size()); } }; diff --git a/include/NetworkEncoding.h b/include/NetworkEncoding.h index 6e4b4f1..60b4bd3 100644 --- a/include/NetworkEncoding.h +++ b/include/NetworkEncoding.h @@ -6,6 +6,7 @@ #define FRNETLIB_NETWORKENCODING_H #include +#include #include #define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) @@ -53,6 +54,22 @@ inline void *get_sin_addr(struct sockaddr *sa) return &(((sockaddr_in6*)sa)->sin6_addr); } +inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block) +{ + //Don't update it if we're already in that mode + if(should_block == is_blocking_already) + return; + + //Different API calls needed for both windows and unix + #ifdef WIN32 + u_long non_blocking = should_block ? 0 : 1; + ioctlsocket(socket_descriptor, FIONBIO, &non_blocking); + #else + int flags = fcntl(socket_descriptor, F_GETFL, 0); + fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK); + #endif +} + //Windows and UNIX require some different headers. //We also need some compatibility defines for cross platform support. diff --git a/include/SSLListener.h b/include/SSLListener.h index 08e4a92..2beac98 100644 --- a/include/SSLListener.h +++ b/include/SSLListener.h @@ -46,6 +46,18 @@ namespace fr */ virtual Socket::Status accept(SSLSocket &client); + /*! + * Enables or disables blocking on the socket. + * + * @param should_block True to block, false otherwise. + */ + virtual void set_blocking(bool should_block) override {abort();}; //Not implemented + + virtual int32_t get_socket_descriptor() const override + { + return listen_fd.fd; + } + private: mbedtls_net_context listen_fd; mbedtls_entropy_context entropy; @@ -55,10 +67,10 @@ namespace fr mbedtls_pk_context pkey; //Stubs - virtual Status send(const Packet &packet){return Socket::Error;} - virtual Status receive(Packet &packet){return Socket::Error;} virtual void close(){} virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;} + virtual Status send_raw(const char *data, size_t size) {return Socket::Error;} + virtual Status receive_raw(void *data, size_t data_size, size_t &received) {return Socket::Error;} }; } diff --git a/include/SSLSocket.h b/include/SSLSocket.h index efea3a3..667a780 100644 --- a/include/SSLSocket.h +++ b/include/SSLSocket.h @@ -66,11 +66,14 @@ const std::string certs = namespace fr { - class SSLSocket : public TcpSocket + class SSLSocket : public Socket { public: - SSLSocket(); - ~SSLSocket(); + SSLSocket() noexcept; + + ~SSLSocket() noexcept; + + SSLSocket(SSLSocket &&) noexcept = default; /*! * Effectively just fr::TcpSocket::send_raw() with encryption @@ -80,7 +83,7 @@ namespace fr * @param size The number of bytes, from data to send. Be careful not to overflow. * @return The status of the operation. */ - Status send_raw(const char *data, size_t size) override; + Socket::Status send_raw(const char *data, size_t size) override; /*! @@ -92,7 +95,7 @@ namespace fr * @param received Will be filled with the number of bytes actually received, might be less than you requested. * @return The status of the operation, if the socket has disconnected etc. */ - Status receive_raw(void *data, size_t data_size, size_t &received) override; + Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override; /*! * Close the connection. @@ -108,10 +111,44 @@ namespace fr */ Socket::Status connect(const std::string &address, const std::string &port) override; + /*! + * Set the SSL context + * + * @param context The SSL context to use + */ void set_ssl_context(std::unique_ptr context); + + /*! + * Set the NET context + * + * @param context The NET context to use + */ void set_net_context(std::unique_ptr context); + /*! + * Gets the underlying socket descriptor. + * + * @return The socket's descriptor. + */ + virtual int32_t get_socket_descriptor() const override + { + return ssl_socket_descriptor->fd; + } + + /*! + * Sets if the socket should block or not. + * + * @param should_block True to block, false otherwise. + */ + virtual void set_blocking(bool should_block) override + { + abort(); + } + private: + std::string unprocessed_buffer; + std::unique_ptr recv_buffer; + std::unique_ptr ssl_socket_descriptor; mbedtls_entropy_context entropy; mbedtls_ctr_drbg_context ctr_drbg; diff --git a/include/Socket.h b/include/Socket.h index e5dae25..2975769 100644 --- a/include/Socket.h +++ b/include/Socket.h @@ -28,27 +28,8 @@ namespace fr VerificationFailed = 9, }; - Socket() - : is_blocking(true) - { - - } - - /*! - * Send a packet through the socket - * - * @param packet The packet to send - * @return A status enum value indicating if the operation succeeded or not. Success on success, Error on failure, Disconnected on disconnection etc. - */ - virtual Status send(const Packet &packet)=0; - - /*! - * Receive a packet through the socket - * - * @param packet The packet to receive - * @return A status enum value indicating if the operation succeeded or not. Success on success, Error on failure, Disconnected on disconnection, WouldBlock if the socket is non-blocking and the socket was not ready to receive, etc. - */ - virtual Status receive(Packet &packet)=0; + Socket() noexcept; + virtual ~Socket() noexcept = default; /*! * Close the connection. @@ -64,16 +45,6 @@ namespace fr */ virtual Socket::Status connect(const std::string &address, const std::string &port)=0; - /*! - * Sets the socket's printable remote address - * - * @param addr The string address - */ - inline virtual void set_remote_address(const std::string &addr) - { - remote_address = addr; - } - /*! * Gets the socket's printable remote address * @@ -84,43 +55,89 @@ namespace fr return remote_address; } - /*! - * Gets the socket descriptor of the object - * - * @return The socket file descriptor - */ - inline int32_t get_socket_descriptor() const - { - return socket_descriptor; - } - /*! * Sets the socket to blocking or non-blocking. * * @param should_block True for blocking (default argument), false otherwise. */ - inline virtual void set_blocking(bool should_block = true) + virtual void set_blocking(bool should_block = true) = 0; + + /*! + * Attempts to send raw data down the socket, without + * any of frnetlib's framing. Useful for communicating through + * different protocols. + * + * @param data The data to send. + * @param size The number of bytes, from data to send. Be careful not to overflow. + * @return The status of the operation. + */ + virtual Status send_raw(const char *data, size_t size) = 0; + + + /*! + * Receives raw data from the socket, without any of + * frnetlib's framing. Useful for communicating through + * different protocols. This will attempt to read 'data_size' + * bytes, but might not succeed. It'll return how many bytes were actually + * read in 'received'. + * + * @param data Where to store the received data. + * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. + * @param received Will be filled with the number of bytes actually received, might be less than you requested. + * @return The status of the operation, if the socket has disconnected etc. + */ + virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0; + + /*! + * Send a packet through the socket + * + * @param packet The packet to send + * @return True on success, false on failure. + */ + Status send(const Packet &packet); + + /*! + * Receive a packet through the socket + * + * @param packet The packet to receive + * @return True on success, false on failure. + */ + Status receive(Packet &packet); + + /*! + * Reads size bytes into dest from the socket. + * Unlike receive_raw, this will keep trying + * to receive data until 'size' bytes have been + * read, or the client has disconnected/there was + * an error. + * + * @param dest Where to read the data into + * @param size The number of bytes to read + * @return Operation status. + */ + Status receive_all(void *dest, size_t size); + + /*! + * Checks to see if we're connected to a socket or not + * + * @return True if it's connected. False otherwise. + */ + inline bool connected() const { - //Don't update it if we're already in that mode - if(should_block == is_blocking) - return; - - //Different API calls needed for both windows and unix - #ifdef WIN32 - u_long non_blocking = should_block ? 0 : 1; - ioctlsocket(socket_descriptor, FIONBIO, &non_blocking); - #else - int flags = fcntl(socket_descriptor, F_GETFL, 0); - fcntl(socket_descriptor, F_SETFL, is_blocking ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK); - #endif - - is_blocking = should_block; + return is_connected; } + /*! + * Gets the socket descriptor. + * + * @return The socket descriptor. + */ + virtual int32_t get_socket_descriptor() const = 0; + protected: - int32_t socket_descriptor; std::string remote_address; bool is_blocking; + bool is_connected; }; } diff --git a/include/TcpListener.h b/include/TcpListener.h index c800a4a..4711d3f 100644 --- a/include/TcpListener.h +++ b/include/TcpListener.h @@ -37,9 +37,9 @@ public: virtual Socket::Status accept(TcpSocket &client); private: + int32_t socket_descriptor; + //Stubs - virtual Status send(const Packet &packet){return Socket::Error;} - virtual Status receive(Packet &packet){return Socket::Error;} virtual void close(){} virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;} }; diff --git a/include/TcpSocket.h b/include/TcpSocket.h index 3dac7b2..9da76c5 100644 --- a/include/TcpSocket.h +++ b/include/TcpSocket.h @@ -20,22 +20,6 @@ public: TcpSocket(TcpSocket &&) noexcept = default; void operator=(const TcpSocket &other)=delete; - /*! - * Send a packet through the socket - * - * @param packet The packet to send - * @return True on success, false on failure. - */ - virtual Status send(const Packet &packet); - - /*! - * Receive a packet through the socket - * - * @param packet The packet to receive - * @return True on success, false on failure. - */ - virtual Status receive(Packet &packet); - /*! * Close the connection. */ @@ -57,16 +41,6 @@ public: */ virtual void set_descriptor(int descriptor); - /*! - * Checks to see if we're connected to a socket or not - * - * @return True if it's connected. False otherwise. - */ - inline bool connected() const - { - return is_connected; - } - /*! * Attempts to send raw data down the socket, without * any of frnetlib's framing. Useful for communicating through @@ -76,7 +50,7 @@ public: * @param size The number of bytes, from data to send. Be careful not to overflow. * @return The status of the operation. */ - virtual Status send_raw(const char *data, size_t size); + virtual Status send_raw(const char *data, size_t size) override; /*! @@ -91,25 +65,36 @@ public: * @param received Will be filled with the number of bytes actually received, might be less than you requested. * @return The status of the operation, if the socket has disconnected etc. */ - virtual Status receive_raw(void *data, size_t data_size, size_t &received); + virtual Status receive_raw(void *data, size_t data_size, size_t &received) override; + + /*! + * Sets the connections remote address. + * + * @param addr The remote address to use + */ + void set_remote_address(const std::string &addr) + { + remote_address = addr; + } + + /*! + * Sets if the socket should be blocking or non-blocking. + * + * @param should_block True to block, false otherwise. + */ + virtual void set_blocking(bool should_block) override; + + /*! + * Gets the unerlying socket descriptor + * + * @return The socket descriptor + */ + int32_t get_socket_descriptor() const override; protected: - /*! - * Reads size bytes into dest from the socket. - * Unlike receive_raw, this will keep trying - * to receive data until 'size' bytes have been - * read, or the client has disconnected/there was - * an error. - * - * @param dest Where to read the data into - * @param size The number of bytes to read - * @return Operation status. - */ - Status receive_all(void *dest, size_t size); - std::string unprocessed_buffer; std::unique_ptr recv_buffer; - bool is_connected; + int32_t socket_descriptor; }; } diff --git a/main.cpp b/main.cpp index 26a28e6..019fc0c 100644 --- a/main.cpp +++ b/main.cpp @@ -11,40 +11,40 @@ int main() { -// fr::SSLListener listener; -// if(listener.listen("9091") != fr::Socket::Success) -// { -// std::cout << "Failed to bind to port" << std::endl; -// return 1; -// } -// -// while(true) -// { -// fr::HttpSocket http_socket; -// if(listener.accept(http_socket) != fr::Socket::Success) -// { -// std::cout << "Failed to accept client" << std::endl; -// continue; -// } -// -// fr::HttpRequest request; -// if(http_socket.receive(request) != fr::Socket::Success) -// { -// std::cout << "Failed to receive data" << std::endl; -// continue; -// } -// else -// { -// std::cout << "Read successfully" << std::endl; -// } -// -// std::cout << "Got: " << request.get_body() << std::endl; -// -// fr::HttpResponse response; -// response.set_body("

Hello, SSL World!

"); -// http_socket.send(response); -// http_socket.close(); -// } + fr::SSLListener listener; + if(listener.listen("9091") != fr::Socket::Success) + { + std::cout << "Failed to bind to port" << std::endl; + return 1; + } + + while(true) + { + fr::HttpSocket http_socket; + if(listener.accept(http_socket) != fr::Socket::Success) + { + std::cout << "Failed to accept client" << std::endl; + continue; + } + + fr::HttpRequest request; + if(http_socket.receive(request) != fr::Socket::Success) + { + std::cout << "Failed to receive data" << std::endl; + continue; + } + else + { + std::cout << "Read successfully" << std::endl; + } + + std::cout << "Got: " << request.get_body() << std::endl; + + fr::HttpResponse response; + response.set_body("

Hello, SSL World!

"); + http_socket.send(response); + http_socket.close(); + } // fr::SSLSocket socket; diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index dcf6afe..6d63469 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -27,14 +27,14 @@ namespace fr return; } - error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len); + error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *)mbedtls_test_cas_pem, mbedtls_test_cas_pem_len); if(error != 0) { std::cout << "Failed to initialise SSL listener. PEM Parse returned: " << error << std::endl; return; } - error = mbedtls_pk_parse_key(&pkey, (const unsigned char *) mbedtls_test_srv_key, mbedtls_test_srv_key_len, NULL, 0); + error = mbedtls_pk_parse_key(&pkey, (const unsigned char *)mbedtls_test_srv_key, mbedtls_test_srv_key_len, NULL, 0); if(error != 0) { std::cout << "Failed to initialise SSL listener. Private Key Parse returned: " << error << std::endl; @@ -44,7 +44,8 @@ namespace fr //Seed random number generator if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0) { - std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error << std::endl; + std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error + << std::endl; return; } @@ -64,7 +65,6 @@ namespace fr std::cout << "Failed to set certificate: " << error << std::endl; return; } - } SSLListener::~SSLListener() @@ -128,4 +128,4 @@ namespace fr } } -#endif //SSL_ENABLED \ No newline at end of file +#endif \ No newline at end of file diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 115ccc0..15dfc04 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -8,7 +8,8 @@ namespace fr { - SSLSocket::SSLSocket() + SSLSocket::SSLSocket() noexcept + : recv_buffer(new char[RECV_CHUNK_SIZE]) { int error = 0; const char *pers = "ssl_client1"; @@ -34,7 +35,7 @@ namespace fr } } - SSLSocket::~SSLSocket() + SSLSocket::~SSLSocket() noexcept { //Close connection if active close(); diff --git a/src/Socket.cpp b/src/Socket.cpp index e5d7c98..ed7a063 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -3,3 +3,66 @@ // #include "Socket.h" + +namespace fr +{ + + Socket::Socket() noexcept + : is_blocking(true), + is_connected(false) + { + + } + + Socket::Status Socket::send(const Packet &packet) + { + //Get packet data + std::string data = packet.get_buffer(); + + //Prepend packet length + uint32_t length = htonl((uint32_t)data.size()); + data.insert(0, "1234"); + memcpy(&data[0], &length, sizeof(uint32_t)); + + //Send it + return send_raw(data.c_str(), data.size()); + } + + Socket::Status Socket::receive(Packet &packet) + { + Socket::Status status; + + //Try to read packet length + uint32_t packet_length = 0; + status = receive_all(&packet_length, sizeof(packet_length)); + if(status != Socket::Status::Success) + return status; + packet_length = ntohl(packet_length); + + //Now we've got the length, read the rest of the data in + std::string data(packet_length, 'c'); + status = receive_all(&data[0], packet_length); + if(status != Socket::Status::Success) + return status; + + //Set the packet to what we've read + packet.set_buffer(std::move(data)); + + return Socket::Status::Success; + } + + Socket::Status Socket::receive_all(void *dest, size_t size) + { + size_t bytes_read = 0; + while(bytes_read < size) + { + size_t read = 0; + Socket::Status status = receive_raw((uintptr_t*)dest + bytes_read, size, read); + if(status == Socket::Status::Success) + bytes_read += read; + else + return status; + } + return Socket::Status::Success; + } +} \ No newline at end of file diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 4801ccf..09063bb 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -9,8 +9,7 @@ namespace fr { TcpSocket::TcpSocket() noexcept - : recv_buffer(new char[RECV_CHUNK_SIZE]), - is_connected(false) + : recv_buffer(new char[RECV_CHUNK_SIZE]) { } @@ -20,20 +19,6 @@ namespace fr close(); } - Socket::Status TcpSocket::send(const Packet &packet) - { - //Get packet data - std::string data = packet.get_buffer(); - - //Prepend packet length - uint32_t length = htonl((uint32_t)data.size()); - data.insert(0, "1234"); - memcpy(&data[0], &length, sizeof(uint32_t)); - - //Send it - return send_raw(data.c_str(), data.size()); - } - Socket::Status TcpSocket::send_raw(const char *data, size_t size) { size_t sent = 0; @@ -58,29 +43,6 @@ namespace fr return Socket::Status::Success; } - Socket::Status TcpSocket::receive(Packet &packet) - { - Socket::Status status; - - //Try to read packet length - uint32_t packet_length = 0; - status = receive_all(&packet_length, sizeof(packet_length)); - if(status != Socket::Status::Success) - return status; - packet_length = ntohl(packet_length); - - //Now we've got the length, read the rest of the data in - std::string data(packet_length, 'c'); - status = receive_all(&data[0], packet_length); - if(status != Socket::Status::Success) - return status; - - //Set the packet to what we've read - packet.set_buffer(std::move(data)); - - return Socket::Status::Success; - } - void TcpSocket::close() { if(is_connected) @@ -90,21 +52,6 @@ namespace fr } } - Socket::Status TcpSocket::receive_all(void *dest, size_t size) - { - size_t bytes_read = 0; - while(bytes_read < size) - { - size_t read = 0; - Socket::Status status = receive_raw((uintptr_t*)dest + bytes_read, size, read); - if(status == Socket::Status::Success) - bytes_read += read; - else - return status; - } - return Socket::Status::Success; - } - Socket::Status TcpSocket::receive_raw(void *data, size_t data_size, size_t &received) { received = 0; @@ -200,4 +147,15 @@ namespace fr return Socket::Status::Success; } + void TcpSocket::set_blocking(bool should_block) + { + set_unix_socket_blocking(socket_descriptor, is_blocking, should_block); + is_blocking = should_block; + } + + int32_t TcpSocket::get_socket_descriptor() const + { + return socket_descriptor; + } + } \ No newline at end of file