Code refactoring

This commit is contained in:
Cloaked9000 2016-12-15 14:57:01 +00:00
parent 14fccb84c9
commit 69d183ed18
12 changed files with 294 additions and 204 deletions

View File

@ -46,7 +46,7 @@ namespace fr
*/ */
Socket::Status send(const Http &request) Socket::Status send(const Http &request)
{ {
std::string data = request.construct(SocketType::remote_address); std::string data = request.construct(SocketType::get_remote_address());
return SocketType::send_raw(&data[0], data.size()); return SocketType::send_raw(&data[0], data.size());
} }
}; };

View File

@ -6,6 +6,7 @@
#define FRNETLIB_NETWORKENCODING_H #define FRNETLIB_NETWORKENCODING_H
#include <netinet/in.h> #include <netinet/in.h>
#include <fcntl.h>
#include <cstring> #include <cstring>
#define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32)) #define htonll(x) ((1==htonl(1)) ? (x) : ((uint64_t)htonl((x) & 0xFFFFFFFF) << 32) | htonl((x) >> 32))
@ -53,6 +54,22 @@ inline void *get_sin_addr(struct sockaddr *sa)
return &(((sockaddr_in6*)sa)->sin6_addr); return &(((sockaddr_in6*)sa)->sin6_addr);
} }
inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
{
//Don't update it if we're already in that mode
if(should_block == is_blocking_already)
return;
//Different API calls needed for both windows and unix
#ifdef WIN32
u_long non_blocking = should_block ? 0 : 1;
ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
#else
int flags = fcntl(socket_descriptor, F_GETFL, 0);
fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
#endif
}
//Windows and UNIX require some different headers. //Windows and UNIX require some different headers.
//We also need some compatibility defines for cross platform support. //We also need some compatibility defines for cross platform support.

View File

@ -46,6 +46,18 @@ namespace fr
*/ */
virtual Socket::Status accept(SSLSocket &client); virtual Socket::Status accept(SSLSocket &client);
/*!
* Enables or disables blocking on the socket.
*
* @param should_block True to block, false otherwise.
*/
virtual void set_blocking(bool should_block) override {abort();}; //Not implemented
virtual int32_t get_socket_descriptor() const override
{
return listen_fd.fd;
}
private: private:
mbedtls_net_context listen_fd; mbedtls_net_context listen_fd;
mbedtls_entropy_context entropy; mbedtls_entropy_context entropy;
@ -55,10 +67,10 @@ namespace fr
mbedtls_pk_context pkey; mbedtls_pk_context pkey;
//Stubs //Stubs
virtual Status send(const Packet &packet){return Socket::Error;}
virtual Status receive(Packet &packet){return Socket::Error;}
virtual void close(){} virtual void close(){}
virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;} virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;}
virtual Status send_raw(const char *data, size_t size) {return Socket::Error;}
virtual Status receive_raw(void *data, size_t data_size, size_t &received) {return Socket::Error;}
}; };
} }

View File

@ -66,11 +66,14 @@ const std::string certs =
namespace fr namespace fr
{ {
class SSLSocket : public TcpSocket class SSLSocket : public Socket
{ {
public: public:
SSLSocket(); SSLSocket() noexcept;
~SSLSocket();
~SSLSocket() noexcept;
SSLSocket(SSLSocket &&) noexcept = default;
/*! /*!
* Effectively just fr::TcpSocket::send_raw() with encryption * Effectively just fr::TcpSocket::send_raw() with encryption
@ -80,7 +83,7 @@ namespace fr
* @param size The number of bytes, from data to send. Be careful not to overflow. * @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation. * @return The status of the operation.
*/ */
Status send_raw(const char *data, size_t size) override; Socket::Status send_raw(const char *data, size_t size) override;
/*! /*!
@ -92,7 +95,7 @@ namespace fr
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc. * @return The status of the operation, if the socket has disconnected etc.
*/ */
Status receive_raw(void *data, size_t data_size, size_t &received) override; Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override;
/*! /*!
* Close the connection. * Close the connection.
@ -108,10 +111,44 @@ namespace fr
*/ */
Socket::Status connect(const std::string &address, const std::string &port) override; Socket::Status connect(const std::string &address, const std::string &port) override;
/*!
* Set the SSL context
*
* @param context The SSL context to use
*/
void set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context); void set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context);
/*!
* Set the NET context
*
* @param context The NET context to use
*/
void set_net_context(std::unique_ptr<mbedtls_net_context> context); void set_net_context(std::unique_ptr<mbedtls_net_context> context);
/*!
* Gets the underlying socket descriptor.
*
* @return The socket's descriptor.
*/
virtual int32_t get_socket_descriptor() const override
{
return ssl_socket_descriptor->fd;
}
/*!
* Sets if the socket should block or not.
*
* @param should_block True to block, false otherwise.
*/
virtual void set_blocking(bool should_block) override
{
abort();
}
private: private:
std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor; std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor;
mbedtls_entropy_context entropy; mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg; mbedtls_ctr_drbg_context ctr_drbg;

View File

@ -28,27 +28,8 @@ namespace fr
VerificationFailed = 9, VerificationFailed = 9,
}; };
Socket() Socket() noexcept;
: is_blocking(true) virtual ~Socket() noexcept = default;
{
}
/*!
* Send a packet through the socket
*
* @param packet The packet to send
* @return A status enum value indicating if the operation succeeded or not. Success on success, Error on failure, Disconnected on disconnection etc.
*/
virtual Status send(const Packet &packet)=0;
/*!
* Receive a packet through the socket
*
* @param packet The packet to receive
* @return A status enum value indicating if the operation succeeded or not. Success on success, Error on failure, Disconnected on disconnection, WouldBlock if the socket is non-blocking and the socket was not ready to receive, etc.
*/
virtual Status receive(Packet &packet)=0;
/*! /*!
* Close the connection. * Close the connection.
@ -64,16 +45,6 @@ namespace fr
*/ */
virtual Socket::Status connect(const std::string &address, const std::string &port)=0; virtual Socket::Status connect(const std::string &address, const std::string &port)=0;
/*!
* Sets the socket's printable remote address
*
* @param addr The string address
*/
inline virtual void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
/*! /*!
* Gets the socket's printable remote address * Gets the socket's printable remote address
* *
@ -84,43 +55,89 @@ namespace fr
return remote_address; return remote_address;
} }
/*!
* Gets the socket descriptor of the object
*
* @return The socket file descriptor
*/
inline int32_t get_socket_descriptor() const
{
return socket_descriptor;
}
/*! /*!
* Sets the socket to blocking or non-blocking. * Sets the socket to blocking or non-blocking.
* *
* @param should_block True for blocking (default argument), false otherwise. * @param should_block True for blocking (default argument), false otherwise.
*/ */
inline virtual void set_blocking(bool should_block = true) virtual void set_blocking(bool should_block = true) = 0;
/*!
* Attempts to send raw data down the socket, without
* any of frnetlib's framing. Useful for communicating through
* different protocols.
*
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
*/
virtual Status send_raw(const char *data, size_t size) = 0;
/*!
* Receives raw data from the socket, without any of
* frnetlib's framing. Useful for communicating through
* different protocols. This will attempt to read 'data_size'
* bytes, but might not succeed. It'll return how many bytes were actually
* read in 'received'.
*
* @param data Where to store the received data.
* @param data_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc.
*/
virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0;
/*!
* Send a packet through the socket
*
* @param packet The packet to send
* @return True on success, false on failure.
*/
Status send(const Packet &packet);
/*!
* Receive a packet through the socket
*
* @param packet The packet to receive
* @return True on success, false on failure.
*/
Status receive(Packet &packet);
/*!
* Reads size bytes into dest from the socket.
* Unlike receive_raw, this will keep trying
* to receive data until 'size' bytes have been
* read, or the client has disconnected/there was
* an error.
*
* @param dest Where to read the data into
* @param size The number of bytes to read
* @return Operation status.
*/
Status receive_all(void *dest, size_t size);
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline bool connected() const
{ {
//Don't update it if we're already in that mode return is_connected;
if(should_block == is_blocking)
return;
//Different API calls needed for both windows and unix
#ifdef WIN32
u_long non_blocking = should_block ? 0 : 1;
ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
#else
int flags = fcntl(socket_descriptor, F_GETFL, 0);
fcntl(socket_descriptor, F_SETFL, is_blocking ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
#endif
is_blocking = should_block;
} }
/*!
* Gets the socket descriptor.
*
* @return The socket descriptor.
*/
virtual int32_t get_socket_descriptor() const = 0;
protected: protected:
int32_t socket_descriptor;
std::string remote_address; std::string remote_address;
bool is_blocking; bool is_blocking;
bool is_connected;
}; };
} }

View File

@ -37,9 +37,9 @@ public:
virtual Socket::Status accept(TcpSocket &client); virtual Socket::Status accept(TcpSocket &client);
private: private:
int32_t socket_descriptor;
//Stubs //Stubs
virtual Status send(const Packet &packet){return Socket::Error;}
virtual Status receive(Packet &packet){return Socket::Error;}
virtual void close(){} virtual void close(){}
virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;} virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;}
}; };

View File

@ -20,22 +20,6 @@ public:
TcpSocket(TcpSocket &&) noexcept = default; TcpSocket(TcpSocket &&) noexcept = default;
void operator=(const TcpSocket &other)=delete; void operator=(const TcpSocket &other)=delete;
/*!
* Send a packet through the socket
*
* @param packet The packet to send
* @return True on success, false on failure.
*/
virtual Status send(const Packet &packet);
/*!
* Receive a packet through the socket
*
* @param packet The packet to receive
* @return True on success, false on failure.
*/
virtual Status receive(Packet &packet);
/*! /*!
* Close the connection. * Close the connection.
*/ */
@ -57,16 +41,6 @@ public:
*/ */
virtual void set_descriptor(int descriptor); virtual void set_descriptor(int descriptor);
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline bool connected() const
{
return is_connected;
}
/*! /*!
* Attempts to send raw data down the socket, without * Attempts to send raw data down the socket, without
* any of frnetlib's framing. Useful for communicating through * any of frnetlib's framing. Useful for communicating through
@ -76,7 +50,7 @@ public:
* @param size The number of bytes, from data to send. Be careful not to overflow. * @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation. * @return The status of the operation.
*/ */
virtual Status send_raw(const char *data, size_t size); virtual Status send_raw(const char *data, size_t size) override;
/*! /*!
@ -91,25 +65,36 @@ public:
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc. * @return The status of the operation, if the socket has disconnected etc.
*/ */
virtual Status receive_raw(void *data, size_t data_size, size_t &received); virtual Status receive_raw(void *data, size_t data_size, size_t &received) override;
/*!
* Sets the connections remote address.
*
* @param addr The remote address to use
*/
void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
/*!
* Sets if the socket should be blocking or non-blocking.
*
* @param should_block True to block, false otherwise.
*/
virtual void set_blocking(bool should_block) override;
/*!
* Gets the unerlying socket descriptor
*
* @return The socket descriptor
*/
int32_t get_socket_descriptor() const override;
protected: protected:
/*!
* Reads size bytes into dest from the socket.
* Unlike receive_raw, this will keep trying
* to receive data until 'size' bytes have been
* read, or the client has disconnected/there was
* an error.
*
* @param dest Where to read the data into
* @param size The number of bytes to read
* @return Operation status.
*/
Status receive_all(void *dest, size_t size);
std::string unprocessed_buffer; std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer; std::unique_ptr<char[]> recv_buffer;
bool is_connected; int32_t socket_descriptor;
}; };
} }

View File

@ -11,40 +11,40 @@
int main() int main()
{ {
// fr::SSLListener listener; fr::SSLListener listener;
// if(listener.listen("9091") != fr::Socket::Success) if(listener.listen("9091") != fr::Socket::Success)
// { {
// std::cout << "Failed to bind to port" << std::endl; std::cout << "Failed to bind to port" << std::endl;
// return 1; return 1;
// } }
//
// while(true) while(true)
// { {
// fr::HttpSocket<fr::SSLSocket> http_socket; fr::HttpSocket<fr::SSLSocket> http_socket;
// if(listener.accept(http_socket) != fr::Socket::Success) if(listener.accept(http_socket) != fr::Socket::Success)
// { {
// std::cout << "Failed to accept client" << std::endl; std::cout << "Failed to accept client" << std::endl;
// continue; continue;
// } }
//
// fr::HttpRequest request; fr::HttpRequest request;
// if(http_socket.receive(request) != fr::Socket::Success) if(http_socket.receive(request) != fr::Socket::Success)
// { {
// std::cout << "Failed to receive data" << std::endl; std::cout << "Failed to receive data" << std::endl;
// continue; continue;
// } }
// else else
// { {
// std::cout << "Read successfully" << std::endl; std::cout << "Read successfully" << std::endl;
// } }
//
// std::cout << "Got: " << request.get_body() << std::endl; std::cout << "Got: " << request.get_body() << std::endl;
//
// fr::HttpResponse response; fr::HttpResponse response;
// response.set_body("<h1>Hello, SSL World!</h1>"); response.set_body("<h1>Hello, SSL World!</h1>");
// http_socket.send(response); http_socket.send(response);
// http_socket.close(); http_socket.close();
// } }
// fr::SSLSocket socket; // fr::SSLSocket socket;

View File

@ -27,14 +27,14 @@ namespace fr
return; return;
} }
error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len); error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *)mbedtls_test_cas_pem, mbedtls_test_cas_pem_len);
if(error != 0) if(error != 0)
{ {
std::cout << "Failed to initialise SSL listener. PEM Parse returned: " << error << std::endl; std::cout << "Failed to initialise SSL listener. PEM Parse returned: " << error << std::endl;
return; return;
} }
error = mbedtls_pk_parse_key(&pkey, (const unsigned char *) mbedtls_test_srv_key, mbedtls_test_srv_key_len, NULL, 0); error = mbedtls_pk_parse_key(&pkey, (const unsigned char *)mbedtls_test_srv_key, mbedtls_test_srv_key_len, NULL, 0);
if(error != 0) if(error != 0)
{ {
std::cout << "Failed to initialise SSL listener. Private Key Parse returned: " << error << std::endl; std::cout << "Failed to initialise SSL listener. Private Key Parse returned: " << error << std::endl;
@ -44,7 +44,8 @@ namespace fr
//Seed random number generator //Seed random number generator
if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0) if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0)
{ {
std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error << std::endl; std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error
<< std::endl;
return; return;
} }
@ -64,7 +65,6 @@ namespace fr
std::cout << "Failed to set certificate: " << error << std::endl; std::cout << "Failed to set certificate: " << error << std::endl;
return; return;
} }
} }
SSLListener::~SSLListener() SSLListener::~SSLListener()
@ -128,4 +128,4 @@ namespace fr
} }
} }
#endif //SSL_ENABLED #endif

View File

@ -8,7 +8,8 @@
namespace fr namespace fr
{ {
SSLSocket::SSLSocket() SSLSocket::SSLSocket() noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE])
{ {
int error = 0; int error = 0;
const char *pers = "ssl_client1"; const char *pers = "ssl_client1";
@ -34,7 +35,7 @@ namespace fr
} }
} }
SSLSocket::~SSLSocket() SSLSocket::~SSLSocket() noexcept
{ {
//Close connection if active //Close connection if active
close(); close();

View File

@ -3,3 +3,66 @@
// //
#include "Socket.h" #include "Socket.h"
namespace fr
{
Socket::Socket() noexcept
: is_blocking(true),
is_connected(false)
{
}
Socket::Status Socket::send(const Packet &packet)
{
//Get packet data
std::string data = packet.get_buffer();
//Prepend packet length
uint32_t length = htonl((uint32_t)data.size());
data.insert(0, "1234");
memcpy(&data[0], &length, sizeof(uint32_t));
//Send it
return send_raw(data.c_str(), data.size());
}
Socket::Status Socket::receive(Packet &packet)
{
Socket::Status status;
//Try to read packet length
uint32_t packet_length = 0;
status = receive_all(&packet_length, sizeof(packet_length));
if(status != Socket::Status::Success)
return status;
packet_length = ntohl(packet_length);
//Now we've got the length, read the rest of the data in
std::string data(packet_length, 'c');
status = receive_all(&data[0], packet_length);
if(status != Socket::Status::Success)
return status;
//Set the packet to what we've read
packet.set_buffer(std::move(data));
return Socket::Status::Success;
}
Socket::Status Socket::receive_all(void *dest, size_t size)
{
size_t bytes_read = 0;
while(bytes_read < size)
{
size_t read = 0;
Socket::Status status = receive_raw((uintptr_t*)dest + bytes_read, size, read);
if(status == Socket::Status::Success)
bytes_read += read;
else
return status;
}
return Socket::Status::Success;
}
}

View File

@ -9,8 +9,7 @@ namespace fr
{ {
TcpSocket::TcpSocket() noexcept TcpSocket::TcpSocket() noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE]), : recv_buffer(new char[RECV_CHUNK_SIZE])
is_connected(false)
{ {
} }
@ -20,20 +19,6 @@ namespace fr
close(); close();
} }
Socket::Status TcpSocket::send(const Packet &packet)
{
//Get packet data
std::string data = packet.get_buffer();
//Prepend packet length
uint32_t length = htonl((uint32_t)data.size());
data.insert(0, "1234");
memcpy(&data[0], &length, sizeof(uint32_t));
//Send it
return send_raw(data.c_str(), data.size());
}
Socket::Status TcpSocket::send_raw(const char *data, size_t size) Socket::Status TcpSocket::send_raw(const char *data, size_t size)
{ {
size_t sent = 0; size_t sent = 0;
@ -58,29 +43,6 @@ namespace fr
return Socket::Status::Success; return Socket::Status::Success;
} }
Socket::Status TcpSocket::receive(Packet &packet)
{
Socket::Status status;
//Try to read packet length
uint32_t packet_length = 0;
status = receive_all(&packet_length, sizeof(packet_length));
if(status != Socket::Status::Success)
return status;
packet_length = ntohl(packet_length);
//Now we've got the length, read the rest of the data in
std::string data(packet_length, 'c');
status = receive_all(&data[0], packet_length);
if(status != Socket::Status::Success)
return status;
//Set the packet to what we've read
packet.set_buffer(std::move(data));
return Socket::Status::Success;
}
void TcpSocket::close() void TcpSocket::close()
{ {
if(is_connected) if(is_connected)
@ -90,21 +52,6 @@ namespace fr
} }
} }
Socket::Status TcpSocket::receive_all(void *dest, size_t size)
{
size_t bytes_read = 0;
while(bytes_read < size)
{
size_t read = 0;
Socket::Status status = receive_raw((uintptr_t*)dest + bytes_read, size, read);
if(status == Socket::Status::Success)
bytes_read += read;
else
return status;
}
return Socket::Status::Success;
}
Socket::Status TcpSocket::receive_raw(void *data, size_t data_size, size_t &received) Socket::Status TcpSocket::receive_raw(void *data, size_t data_size, size_t &received)
{ {
received = 0; received = 0;
@ -200,4 +147,15 @@ namespace fr
return Socket::Status::Success; return Socket::Status::Success;
} }
void TcpSocket::set_blocking(bool should_block)
{
set_unix_socket_blocking(socket_descriptor, is_blocking, should_block);
is_blocking = should_block;
}
int32_t TcpSocket::get_socket_descriptor() const
{
return socket_descriptor;
}
} }