diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b52739..f4cc37c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,7 +9,7 @@ INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR}) include_directories(include) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread -lmbedtls -lmbedx509 -lmbedcrypto") -set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h include/TLSSocket.cpp include/TLSSocket.h) +set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h include/SSLSocket.cpp include/SSLSocket.h include/SSLListener.cpp include/SSLListener.h) add_executable(frnetlib ${SOURCE_FILES}) TARGET_LINK_LIBRARIES(frnetlib ${MBEDTLS_LIBRARIES} -lmbedtls -lmbedx509 -lmbedcrypto) \ No newline at end of file diff --git a/include/SSLListener.cpp b/include/SSLListener.cpp new file mode 100644 index 0000000..8aae3de --- /dev/null +++ b/include/SSLListener.cpp @@ -0,0 +1,128 @@ +// +// Created by fred on 13/12/16. +// + +#include "SSLListener.h" + +namespace fr +{ + SSLListener::SSLListener() noexcept + { + //Initialise SSL objects required + mbedtls_net_init(&listen_fd); + mbedtls_ssl_init(&ssl); + mbedtls_ssl_config_init(&conf); + mbedtls_x509_crt_init(&srvcert); + mbedtls_pk_init(&pkey); + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + + int error = 0; + + //Load certificates and private key todo: Switch from inbuilt test certificates + error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *)mbedtls_test_srv_crt, mbedtls_test_srv_crt_len); + if(error != 0) + { + std::cout << "Failed to initialise SSL listener. CRT Parse returned: " << error << std::endl; + return; + } + + error = mbedtls_x509_crt_parse(&srvcert, (const unsigned char *) mbedtls_test_cas_pem, mbedtls_test_cas_pem_len); + if(error != 0) + { + std::cout << "Failed to initialise SSL listener. PEM Parse returned: " << error << std::endl; + return; + } + + error = mbedtls_pk_parse_key(&pkey, (const unsigned char *) mbedtls_test_srv_key, mbedtls_test_srv_key_len, NULL, 0); + if(error != 0) + { + std::cout << "Failed to initialise SSL listener. Private Key Parse returned: " << error << std::endl; + return; + } + + //Seed random number generator + if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0) + { + std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error << std::endl; + return; + } + + //Setup data structures + if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) + { + std::cout << "Failed to configure SSL presets: " << error << std::endl; + return; + } + + //Apply them + mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg); + mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL); + + if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0) + { + std::cout << "Failed to set certificate: " << error << std::endl; + return; + } + + if((error = mbedtls_ssl_setup( &ssl, &conf ) ) != 0) + { + std::cout << "Failed to apply SSL setings: " << error << std::endl; + return; + } + + } + + SSLListener::~SSLListener() + { + mbedtls_net_free(&listen_fd); + mbedtls_x509_crt_free(&srvcert); + mbedtls_pk_free(&pkey); + mbedtls_ssl_free(&ssl); + mbedtls_ssl_config_free(&conf); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free( &entropy); + } + + Socket::Status fr::SSLListener::listen(const std::string &port) + { + //Bind to port + if(mbedtls_net_bind( &listen_fd, NULL, port.c_str(), MBEDTLS_NET_PROTO_TCP) != 0) + { + return Socket::BindFailed; + } + return Socket::Success; + } + + Socket::Status SSLListener::accept(SSLSocket &client) + { + int error = 0; + + //Accept a connection + mbedtls_net_context client_fd; + mbedtls_net_init(&client_fd); + + if((error = mbedtls_net_accept(&listen_fd, &client_fd, NULL, 0, NULL)) != 0) + { + std::cout << "Accept error: " << error << std::endl; + return Socket::Error; + } + + mbedtls_ssl_set_bio( &ssl, &client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); + + //SSL Handshake + while((error = mbedtls_ssl_handshake(&ssl)) != 0) + { + if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) + { + std::cout << "Handshake failed: " << error << std::end + return Socket::Status::HandshakeFailed; + } + } + + //Set socket details + client.set_descriptor(client_fd.fd); + return Socket::Success; + } + +} \ No newline at end of file diff --git a/include/SSLListener.h b/include/SSLListener.h new file mode 100644 index 0000000..c1666f7 --- /dev/null +++ b/include/SSLListener.h @@ -0,0 +1,68 @@ +// +// Created by fred on 13/12/16. +// + +#ifndef FRNETLIB_SSLLISTENER_H +#define FRNETLIB_SSLLISTENER_H + +//#define SSL_ENABLED + +#ifdef SSL_ENABLED + +#include +#include +#include +#include +#include +#include +#include + +#include "TcpListener.h" +#include "SSLSocket.h" + + +namespace fr +{ + class SSLListener : public Socket + { + public: + SSLListener() noexcept; + virtual ~SSLListener() noexcept; + SSLListener(SSLListener &&o) noexcept = default; + + /*! + * Listens to the given port for connections + * + * @param port The port to bind to + * @return If the operation was successful + */ + virtual Socket::Status listen(const std::string &port); + + /*! + * Accepts a new connection. + * + * @param client Where to store the connection information + * @return True on success. False on failure. + */ + virtual Socket::Status accept(SSLSocket &client); + + private: + mbedtls_net_context listen_fd; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_ssl_context ssl; + mbedtls_ssl_config conf; + mbedtls_x509_crt srvcert; + mbedtls_pk_context pkey; + + //Stubs + virtual Status send(const Packet &packet){return Socket::Error;} + virtual Status receive(Packet &packet){return Socket::Error;} + virtual void close(){} + virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;} + }; + +} + +#endif //SLL_ENABLED +#endif //FRNETLIB_SSLLISTENER_H diff --git a/include/SSLSocket.cpp b/include/SSLSocket.cpp new file mode 100644 index 0000000..52c8e69 --- /dev/null +++ b/include/SSLSocket.cpp @@ -0,0 +1,182 @@ +// +// Created by fred on 12/12/16. +// + +#include "SSLSocket.h" +#ifdef SSL_ENABLED + +namespace fr +{ + SSLSocket::SSLSocket() + { + int error = 0; + const char *pers = "ssl_client1"; + + //Initialise mbedtls structures + mbedtls_net_init(&ssl_socket_descriptor); + mbedtls_ssl_init(&ssl); + mbedtls_ssl_config_init(&conf); + mbedtls_x509_crt_init(&cacert); + mbedtls_ctr_drbg_init(&ctr_drbg); + + //Seed random number generator + mbedtls_entropy_init(&entropy); + if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers))) != 0) + { + std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl; + return; + } + + //Load root CA certificate + if((error = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)certs.c_str(), certs.size() + 1) < 0)) + { + std::cout << "Failed to parse root CA certificate. Parse returned: " << error << std::endl; + return; + } + } + + SSLSocket::~SSLSocket() + { + //Close connection if active + close(); + + //Cleanup mbedsql stuff + mbedtls_net_free(&ssl_socket_descriptor); + mbedtls_x509_crt_free(&cacert); + mbedtls_ssl_free(&ssl); + mbedtls_ssl_config_free(&conf); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + } + + void SSLSocket::close() + { + if(is_connected) + { + mbedtls_ssl_close_notify(&ssl); + is_connected = false; + } + } + + Socket::Status SSLSocket::send_raw(const char *data, size_t size) + { + int error = 0; + while((error = mbedtls_ssl_write(&ssl, (const unsigned char *)data, size)) <= 0) + { + if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) + { + return Socket::Status::Error; + } + } + + return Socket::Status::Success; + } + + Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) + { + int read = 0; + received = 0; + if(unprocessed_buffer.size() < data_size) + { + read = mbedtls_ssl_read(&ssl, (unsigned char *)recv_buffer.get(), RECV_CHUNK_SIZE); + + if(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) + { + received = 0; + return Socket::Status::Success; + } + else if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) + { + std::cout << "disconnected" << std::endl; + return Socket::Status::Disconnected; + } + else if(read <= 0) + { + std::cout << "read <= 0" << std::endl; + return Socket::Status::Error; + } + + received += read; + unprocessed_buffer += {recv_buffer.get(), (size_t)read}; + + if(received > data_size) + received = data_size; + } + else + { + received = data_size; + } + + //Copy data to where it needs to go + memcpy(data, &unprocessed_buffer[0], received); + unprocessed_buffer.erase(0, received); + return Socket::Status::Success; + + } + + void SSLSocket::set_descriptor(int descriptor) + { + is_connected = true; + socket_descriptor = descriptor; + ssl_socket_descriptor.fd = descriptor; + } + + Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) + { + //Initialise the connection using mbedtls + int error = 0; + if((error = mbedtls_net_connect(&ssl_socket_descriptor, address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0) + { + return Socket::Status::ConnectionFailed; + } + + //Initialise SSL data structures + if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) + { + return Socket::Status::Error; + } + + mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_ssl_conf_ca_chain(&conf, &cacert, NULL); + mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg); + + if((error = mbedtls_ssl_setup(&ssl, &conf)) != 0) + { + return Socket::Status::Error; + } + + if((error = mbedtls_ssl_set_hostname(&ssl, address.c_str())) != 0) + { + return Socket::Status::Error; + } + + mbedtls_ssl_set_bio(&ssl, &ssl_socket_descriptor, mbedtls_net_send, mbedtls_net_recv, NULL); + + //Do SSL handshake + while((error = mbedtls_ssl_handshake(&ssl)) != 0) + { + if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) + { + std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl; + return Socket::Status::HandshakeFailed; + } + } + + //Verify server certificate + if((flags = mbedtls_ssl_get_verify_result(&ssl)) != 0) + { + char verify_buffer[512]; + mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags ); + + std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl; + return Socket::Status::VerificationFailed; + } + + //Update members + is_connected = true; + remote_address = address + ":" + port; + return Socket::Status::Success; + } +} + +#endif \ No newline at end of file diff --git a/include/SSLSocket.h b/include/SSLSocket.h new file mode 100644 index 0000000..e93ad27 --- /dev/null +++ b/include/SSLSocket.h @@ -0,0 +1,131 @@ +// +// Created by fred on 12/12/16. +// + +#ifndef FRNETLIB_SSL_SOCKET_H +#define FRNETLIB_SSL_SOCKET_H + +//#define SSL_ENABLED + +#ifdef SSL_ENABLED + +#include "TcpSocket.h" +#include +#include +#include +#include +#include +#include +#include + +const std::string certs = + "-----BEGIN CERTIFICATE-----\n" + "MIIHyTCCBbGgAwIBAgIBATANBgkqhkiG9w0BAQUFADB9MQswCQYDVQQGEwJJTDEW\n" + "MBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0YWwg\n" + "Q2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3RhcnRDb20gQ2VydGlmaWNh\n" + "dGlvbiBBdXRob3JpdHkwHhcNMDYwOTE3MTk0NjM2WhcNMzYwOTE3MTk0NjM2WjB9\n" + "MQswCQYDVQQGEwJJTDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMi\n" + "U2VjdXJlIERpZ2l0YWwgQ2VydGlmaWNhdGUgU2lnbmluZzEpMCcGA1UEAxMgU3Rh\n" + "cnRDb20gQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwggIiMA0GCSqGSIb3DQEBAQUA\n" + "A4ICDwAwggIKAoICAQDBiNsJvGxGfHiflXu1M5DycmLWwTYgIiRezul38kMKogZk\n" + "pMyONvg45iPwbm2xPN1yo4UcodM9tDMr0y+v/uqwQVlntsQGfQqedIXWeUyAN3rf\n" + "OQVSWff0G0ZDpNKFhdLDcfN1YjS6LIp/Ho/u7TTQEceWzVI9ujPW3U3eCztKS5/C\n" + "Ji/6tRYccjV3yjxd5srhJosaNnZcAdt0FCX+7bWgiA/deMotHweXMAEtcnn6RtYT\n" + "Kqi5pquDSR3l8u/d5AGOGAqPY1MWhWKpDhk6zLVmpsJrdAfkK+F2PrRt2PZE4XNi\n" + "HzvEvqBTViVsUQn3qqvKv3b9bZvzndu/PWa8DFaqr5hIlTpL36dYUNk4dalb6kMM\n" + "Av+Z6+hsTXBbKWWc3apdzK8BMewM69KN6Oqce+Zu9ydmDBpI125C4z/eIT574Q1w\n" + "+2OqqGwaVLRcJXrJosmLFqa7LH4XXgVNWG4SHQHuEhANxjJ/GP/89PrNbpHoNkm+\n" + "Gkhpi8KWTRoSsmkXwQqQ1vp5Iki/untp+HDH+no32NgN0nZPV/+Qt+OR0t3vwmC3\n" + "Zzrd/qqc8NSLf3Iizsafl7b4r4qgEKjZ+xjGtrVcUjyJthkqcwEKDwOzEmDyei+B\n" + "26Nu/yYwl/WL3YlXtq09s68rxbd2AvCl1iuahhQqcvbjM4xdCUsT37uMdBNSSwID\n" + "AQABo4ICUjCCAk4wDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAa4wHQYDVR0OBBYE\n" + "FE4L7xqkQFulF2mHMMo0aEPQQa7yMGQGA1UdHwRdMFswLKAqoCiGJmh0dHA6Ly9j\n" + "ZXJ0LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMCugKaAnhiVodHRwOi8vY3Js\n" + "LnN0YXJ0Y29tLm9yZy9zZnNjYS1jcmwuY3JsMIIBXQYDVR0gBIIBVDCCAVAwggFM\n" + "BgsrBgEEAYG1NwEBATCCATswLwYIKwYBBQUHAgEWI2h0dHA6Ly9jZXJ0LnN0YXJ0\n" + "Y29tLm9yZy9wb2xpY3kucGRmMDUGCCsGAQUFBwIBFilodHRwOi8vY2VydC5zdGFy\n" + "dGNvbS5vcmcvaW50ZXJtZWRpYXRlLnBkZjCB0AYIKwYBBQUHAgIwgcMwJxYgU3Rh\n" + "cnQgQ29tbWVyY2lhbCAoU3RhcnRDb20pIEx0ZC4wAwIBARqBl0xpbWl0ZWQgTGlh\n" + "YmlsaXR5LCByZWFkIHRoZSBzZWN0aW9uICpMZWdhbCBMaW1pdGF0aW9ucyogb2Yg\n" + "dGhlIFN0YXJ0Q29tIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFBvbGljeSBhdmFp\n" + "bGFibGUgYXQgaHR0cDovL2NlcnQuc3RhcnRjb20ub3JnL3BvbGljeS5wZGYwEQYJ\n" + "YIZIAYb4QgEBBAQDAgAHMDgGCWCGSAGG+EIBDQQrFilTdGFydENvbSBGcmVlIFNT\n" + "TCBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTANBgkqhkiG9w0BAQUFAAOCAgEAFmyZ\n" + "9GYMNPXQhV59CuzaEE44HF7fpiUFS5Eyweg78T3dRAlbB0mKKctmArexmvclmAk8\n" + "jhvh3TaHK0u7aNM5Zj2gJsfyOZEdUauCe37Vzlrk4gNXcGmXCPleWKYK34wGmkUW\n" + "FjgKXlf2Ysd6AgXmvB618p70qSmD+LIU424oh0TDkBreOKk8rENNZEXO3SipXPJz\n" + "ewT4F+irsfMuXGRuczE6Eri8sxHkfY+BUZo7jYn0TZNmezwD7dOaHZrzZVD1oNB1\n" + "ny+v8OqCQ5j4aZyJecRDjkZy42Q2Eq/3JR44iZB3fsNrarnDy0RLrHiQi+fHLB5L\n" + "EUTINFInzQpdn4XBidUaePKVEFMy3YCEZnXZtWgo+2EuvoSoOMCZEoalHmdkrQYu\n" + "L6lwhceWD3yJZfWOQ1QOq92lgDmUYMA0yZZwLKMS9R9Ie70cfmu3nZD0Ijuu+Pwq\n" + "yvqCUqDvr0tVk+vBtfAii6w0TiYiBKGHLHVKt+V9E9e4DGTANtLJL4YSjCMJwRuC\n" + "O3NJo2pXh5Tl1njFmUNj403gdy3hZZlyaQQaRwnmDwFWJPsfvw55qVguucQJAX6V\n" + "um0ABj6y6koQOdjQK/W/7HW/lwLFCRsI3FU34oH7N4RDYiDK51ZLZer+bMEkkySh\n" + "NOsF/5oirpt9P/FlUQqmMGqz9IgcgA38corog14=\n" + "-----END CERTIFICATE-----"; + +namespace fr +{ + class SSLSocket : public TcpSocket + { + public: + SSLSocket(); + ~SSLSocket(); + + /*! + * Effectively just fr::TcpSocket::send_raw() with encryption + * added in. + * + * @param data The data to send. + * @param size The number of bytes, from data to send. Be careful not to overflow. + * @return The status of the operation. + */ + Status send_raw(const char *data, size_t size) override; + + + /*! + * Effectively just fr::TcpSocket::receive_raw() with encryption + * added in. + * + * @param data Where to store the received data. + * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. + * @param received Will be filled with the number of bytes actually received, might be less than you requested. + * @return The status of the operation, if the socket has disconnected etc. + */ + Status receive_raw(void *data, size_t data_size, size_t &received) override; + + /*! + * Sets the socket file descriptor. + * + * @param descriptor The socket descriptor. + */ + void set_descriptor(int descriptor) override; + + /*! + * Close the connection. + */ + void close() override; + + /*! + * Connects the socket to an address. + * + * @param address The address of the socket to connect to + * @param port The port of the socket to connect to + * @return A Socket::Status indicating the status of the operation. + */ + Socket::Status connect(const std::string &address, const std::string &port) override; + + private: + mbedtls_net_context ssl_socket_descriptor; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_ssl_context ssl; + mbedtls_ssl_config conf; + mbedtls_x509_crt cacert; + uint32_t flags; + }; +} + +#endif + +#endif //FRNETLIB_SSLSOCKET_H diff --git a/include/Socket.h b/include/Socket.h index eb7fa60..e5dae25 100644 --- a/include/Socket.h +++ b/include/Socket.h @@ -23,6 +23,9 @@ namespace fr Disconnected = 4, Error = 5, WouldBlock = 6, + ConnectionFailed = 7, + HandshakeFailed = 8, + VerificationFailed = 9, }; Socket() diff --git a/include/TLSSocket.cpp b/include/TLSSocket.cpp deleted file mode 100644 index e4f2f5f..0000000 --- a/include/TLSSocket.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// -// Created by fred on 12/12/16. -// - -#include "TLSSocket.h" -#ifdef SSL_ENABLED - -namespace fr -{ - TLSSocket::TLSSocket() - { - int error = 0; - - //Initialise mbedtls structures - mbedtls_net_init(&ssl_socket_descriptor); - mbedtls_ssl_init(&ssl); - mbedtls_ssl_config_init(&conf); - mbedtls_x509_crt_init(&cacert); - mbedtls_ctr_drbg_init(&ctr_drbg); - - //Seed random number generator - if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, NULL)) != 0) - { - std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl; - return; - } - - //Load root CA certificate - if((error = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)mbedtls_test_cas_pem, mbedtls_test_cas_pem_len) < 0)) - { - std::cout << "Failed to parse root CA certificate. Parse returned: " << error << std::endl; - return; - } - } - Socket::Status TLSSocket::send_raw(const char *data, size_t size) - { - return TcpSocket::send_raw(data, size); - } - - Socket::Status TLSSocket::receive_raw(void *data, size_t data_size, size_t &received) - { - return TcpSocket::receive_raw(data, data_size, received); - } - - void TLSSocket::set_descriptor(int descriptor) - { - TcpSocket::set_descriptor(descriptor); - } - - void TLSSocket::close() - { - TcpSocket::close(); - } - - Socket::Status TLSSocket::connect(const std::string &address, const std::string &port) - { - return TcpSocket::connect(address, port); - } -} - -#endif \ No newline at end of file diff --git a/include/TLSSocket.h b/include/TLSSocket.h deleted file mode 100644 index 68ebb83..0000000 --- a/include/TLSSocket.h +++ /dev/null @@ -1,83 +0,0 @@ -// -// Created by fred on 12/12/16. -// - -#ifndef FRNETLIB_TLSSOCKET_H -#define FRNETLIB_TLSSOCKET_H - - - -#ifdef SSL_ENABLED - -#include "TcpSocket.h" -#include -#include -#include -#include -#include -#include -#include - -namespace fr -{ - class TLSSocket : public TcpSocket - { - public: - TLSSocket(); - - /*! - * Effectively just fr::TcpSocket::send_raw() with encryption - * added in. - * - * @param data The data to send. - * @param size The number of bytes, from data to send. Be careful not to overflow. - * @return The status of the operation. - */ - Status send_raw(const char *data, size_t size) override; - - - /*! - * Effectively just fr::TcpSocket::receive_raw() with encryption - * added in. - * - * @param data Where to store the received data. - * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. - * @param received Will be filled with the number of bytes actually received, might be less than you requested. - * @return The status of the operation, if the socket has disconnected etc. - */ - Status receive_raw(void *data, size_t data_size, size_t &received) override; - - /*! - * Sets the socket file descriptor. - * - * @param descriptor The socket descriptor. - */ - void set_descriptor(int descriptor) override; - - /*! - * Close the connection. - */ - void close() override; - - /*! - * Connects the socket to an address. - * - * @param address The address of the socket to connect to - * @param port The port of the socket to connect to - * @return A Socket::Status indicating the status of the operation. - */ - Socket::Status connect(const std::string &address, const std::string &port) override; - - private: - mbedtls_net_context ssl_socket_descriptor; - mbedtls_entropy_context entropy; - mbedtls_ctr_drbg_context ctr_drbg; - mbedtls_ssl_context ssl; - mbedtls_ssl_config conf; - mbedtls_x509_crt cacert; - }; -} - -#endif - -#endif //FRNETLIB_TLSSOCKET_H diff --git a/include/TcpListener.h b/include/TcpListener.h index 2a7cff5..c800a4a 100644 --- a/include/TcpListener.h +++ b/include/TcpListener.h @@ -16,13 +16,17 @@ namespace fr class TcpListener : public Socket { public: + TcpListener() noexcept = default; + virtual ~TcpListener() noexcept = default; + TcpListener(TcpListener &&o) noexcept = default; + /*! * Listens to the given port for connections * * @param port The port to bind to * @return If the operation was successful */ - Socket::Status listen(const std::string &port); + virtual Socket::Status listen(const std::string &port); /*! * Accepts a new connection. @@ -30,7 +34,7 @@ public: * @param client Where to store the connection information * @return True on success. False on failure. */ - Socket::Status accept(TcpSocket &client); + virtual Socket::Status accept(TcpSocket &client); private: //Stubs diff --git a/include/TcpSocket.h b/include/TcpSocket.h index 5744074..3dac7b2 100644 --- a/include/TcpSocket.h +++ b/include/TcpSocket.h @@ -93,7 +93,7 @@ public: */ virtual Status receive_raw(void *data, size_t data_size, size_t &received); -private: +protected: /*! * Reads size bytes into dest from the socket. * Unlike receive_raw, this will keep trying diff --git a/main.cpp b/main.cpp index 0e195b5..ef3da8d 100644 --- a/main.cpp +++ b/main.cpp @@ -1,18 +1,54 @@ #include +#include #include "include/Packet.h" #include "include/TcpSocket.h" #include "include/TcpListener.h" #include "include/SocketSelector.h" -#include -#include -#include #include "HttpSocket.h" #include "HttpRequest.h" #include "HttpResponse.h" +#include "SSLSocket.h" int main() { - +// fr::SSLListener listener; +// if(listener.listen("9091") != fr::Socket::Success) +// { +// std::cout << "Failed to bind to port" << std::endl; +// return 1; +// } +// +// fr::SSLSocket socket; +// if(listener.accept(socket) != fr::Socket::Success) +// { +// std::cout << "Failed to accept client" << std::endl; +// return 2; +// } +// +// std::string buf(1024, '\0'); +// size_t received = 0; +// if(socket.receive_raw(&buf[0], buf.size(), received) != fr::Socket::Success) +// { +// std::cout << "Failed to receive data" << std::endl; +// return 3; +// } +// +// std::cout << "Got: " << buf.substr(0, received) << std::endl; +// +// +// fr::SSLSocket socket; +// if(socket.connect("lloydsenpai.xyz", "443") != fr::Socket::Success) +// return 1; +// +// std::string request = "GET / HTTP/1.1\r\nhost: www.lloydsenpai.xyz\r\n\r\n"; +// socket.send_raw(request.c_str(), request.size()); +// +// char *data = new char[1024]; +// size_t received; +// if(socket.receive_raw(data, 1024, received) != fr::Socket::Success) +// return 2; +// +// std::cout << "Got: " << std::string(data, received) << std::endl; return 0; } \ No newline at end of file