From 7e9f007acd99f2134394fe2d6b9b7dc12330849c Mon Sep 17 00:00:00 2001 From: Cloaked9000 Date: Mon, 12 Dec 2016 18:03:27 +0000 Subject: [PATCH] TLS Test Code --- CMakeLists.txt | 15 +++-- cmake_modules/FindMBEDTLS.cmake | 36 +++++++++++ include/TLSSocket.cpp | 61 ++++++++++++++++++ include/TLSSocket.h | 83 ++++++++++++++++++++++++ include/TcpSocket.h | 6 +- main.cpp | 109 -------------------------------- 6 files changed, 194 insertions(+), 116 deletions(-) create mode 100644 cmake_modules/FindMBEDTLS.cmake create mode 100644 include/TLSSocket.cpp create mode 100644 include/TLSSocket.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 0671483..0b52739 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,15 @@ cmake_minimum_required(VERSION 3.6) project(frnetlib) -include_directories(include) -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread") +set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} /home/fred/ClionProjects/frnetlib/cmake_modules) -set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h) -add_executable(frnetlib ${SOURCE_FILES}) \ No newline at end of file +FIND_PACKAGE(MBEDTLS) +INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR}) + +include_directories(include) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread -lmbedtls -lmbedx509 -lmbedcrypto") + +set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h include/TLSSocket.cpp include/TLSSocket.h) +add_executable(frnetlib ${SOURCE_FILES}) + +TARGET_LINK_LIBRARIES(frnetlib ${MBEDTLS_LIBRARIES} -lmbedtls -lmbedx509 -lmbedcrypto) \ No newline at end of file diff --git a/cmake_modules/FindMBEDTLS.cmake b/cmake_modules/FindMBEDTLS.cmake new file mode 100644 index 0000000..9f302b3 --- /dev/null +++ b/cmake_modules/FindMBEDTLS.cmake @@ -0,0 +1,36 @@ +find_path(MBEDTLS_INCLUDE_DIR + NAMES mbedtls/ssl.h + PATH_SUFFIXES include + HINTS ${MBEDTLS_ROOT}) + +find_library(MBEDTLS_LIBRARY + NAMES mbedtls + PATH_SUFFIXES lib + HINTS ${MBEDTLS_ROOT}) + +find_library(MBEDCRYPTO_LIBRARY + NAMES mbedcrypto + PATH_SUFFIXES lib + HINTS ${MBEDTLS_ROOT}) + +find_library(MBEDX509_LIBRARY + NAMES mbedx509 + PATH_SUFFIXES lib + HINTS ${MBEDTLS_ROOT}) + +if(MBEDTLS_INCLUDE_DIR AND MBEDTLS_LIBRARY) + set(MBEDTLS_FOUND TRUE) + set(MBEDTLS_LIBRARIES ${MBEDTLS_LIBRARY} ${MBEDCRYPTO_LIBRARY} ${MBEDX509_LIBRARY}) +endif() + +if(MBEDTLS_FOUND) + if(NOT MBEDTLS_FIND_QUIETLY) + message(STATUS "Found mbed TLS: ${MBEDTLS_LIBRARIES}") + endif() +else() + if(MBEDTLS_FIND_REQUIRED) + message(FATAL_ERROR "mbed TLS was not found") + endif() +endif() + +mark_as_advanced(MBEDTLS_INCLUDE_DIR MBEDTLS_LIBRARY) diff --git a/include/TLSSocket.cpp b/include/TLSSocket.cpp new file mode 100644 index 0000000..e4f2f5f --- /dev/null +++ b/include/TLSSocket.cpp @@ -0,0 +1,61 @@ +// +// Created by fred on 12/12/16. +// + +#include "TLSSocket.h" +#ifdef SSL_ENABLED + +namespace fr +{ + TLSSocket::TLSSocket() + { + int error = 0; + + //Initialise mbedtls structures + mbedtls_net_init(&ssl_socket_descriptor); + mbedtls_ssl_init(&ssl); + mbedtls_ssl_config_init(&conf); + mbedtls_x509_crt_init(&cacert); + mbedtls_ctr_drbg_init(&ctr_drbg); + + //Seed random number generator + if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, NULL)) != 0) + { + std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl; + return; + } + + //Load root CA certificate + if((error = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)mbedtls_test_cas_pem, mbedtls_test_cas_pem_len) < 0)) + { + std::cout << "Failed to parse root CA certificate. Parse returned: " << error << std::endl; + return; + } + } + Socket::Status TLSSocket::send_raw(const char *data, size_t size) + { + return TcpSocket::send_raw(data, size); + } + + Socket::Status TLSSocket::receive_raw(void *data, size_t data_size, size_t &received) + { + return TcpSocket::receive_raw(data, data_size, received); + } + + void TLSSocket::set_descriptor(int descriptor) + { + TcpSocket::set_descriptor(descriptor); + } + + void TLSSocket::close() + { + TcpSocket::close(); + } + + Socket::Status TLSSocket::connect(const std::string &address, const std::string &port) + { + return TcpSocket::connect(address, port); + } +} + +#endif \ No newline at end of file diff --git a/include/TLSSocket.h b/include/TLSSocket.h new file mode 100644 index 0000000..68ebb83 --- /dev/null +++ b/include/TLSSocket.h @@ -0,0 +1,83 @@ +// +// Created by fred on 12/12/16. +// + +#ifndef FRNETLIB_TLSSOCKET_H +#define FRNETLIB_TLSSOCKET_H + + + +#ifdef SSL_ENABLED + +#include "TcpSocket.h" +#include +#include +#include +#include +#include +#include +#include + +namespace fr +{ + class TLSSocket : public TcpSocket + { + public: + TLSSocket(); + + /*! + * Effectively just fr::TcpSocket::send_raw() with encryption + * added in. + * + * @param data The data to send. + * @param size The number of bytes, from data to send. Be careful not to overflow. + * @return The status of the operation. + */ + Status send_raw(const char *data, size_t size) override; + + + /*! + * Effectively just fr::TcpSocket::receive_raw() with encryption + * added in. + * + * @param data Where to store the received data. + * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. + * @param received Will be filled with the number of bytes actually received, might be less than you requested. + * @return The status of the operation, if the socket has disconnected etc. + */ + Status receive_raw(void *data, size_t data_size, size_t &received) override; + + /*! + * Sets the socket file descriptor. + * + * @param descriptor The socket descriptor. + */ + void set_descriptor(int descriptor) override; + + /*! + * Close the connection. + */ + void close() override; + + /*! + * Connects the socket to an address. + * + * @param address The address of the socket to connect to + * @param port The port of the socket to connect to + * @return A Socket::Status indicating the status of the operation. + */ + Socket::Status connect(const std::string &address, const std::string &port) override; + + private: + mbedtls_net_context ssl_socket_descriptor; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_ssl_context ssl; + mbedtls_ssl_config conf; + mbedtls_x509_crt cacert; + }; +} + +#endif + +#endif //FRNETLIB_TLSSOCKET_H diff --git a/include/TcpSocket.h b/include/TcpSocket.h index e282685..5744074 100644 --- a/include/TcpSocket.h +++ b/include/TcpSocket.h @@ -55,7 +55,7 @@ public: * * @param descriptor The socket descriptor. */ - void set_descriptor(int descriptor); + virtual void set_descriptor(int descriptor); /*! * Checks to see if we're connected to a socket or not @@ -76,7 +76,7 @@ public: * @param size The number of bytes, from data to send. Be careful not to overflow. * @return The status of the operation. */ - Status send_raw(const char *data, size_t size); + virtual Status send_raw(const char *data, size_t size); /*! @@ -91,7 +91,7 @@ public: * @param received Will be filled with the number of bytes actually received, might be less than you requested. * @return The status of the operation, if the socket has disconnected etc. */ - Status receive_raw(void *data, size_t data_size, size_t &received); + virtual Status receive_raw(void *data, size_t data_size, size_t &received); private: /*! diff --git a/main.cpp b/main.cpp index 19eff82..0e195b5 100644 --- a/main.cpp +++ b/main.cpp @@ -10,117 +10,8 @@ #include "HttpRequest.h" #include "HttpResponse.h" -void server() -{ - //Bind to port - fr::TcpListener listener; - if(listener.listen("8081") != fr::Socket::Success) - { - std::cout << "Failed to listen to port" << std::endl; - return; - } - - //Create a selector and a container for holding connected clients - fr::SocketSelector selector; - std::vector> clients; - - //Add our connection listener to the selector - selector.add(listener); - - //Infinitely loop, waiting for connections or data - while(selector.wait()) - { - //If the listener is ready, that means we've got a new connection - if(selector.is_ready(listener)) - { - //Try and add them to our client container - clients.emplace_back(new fr::HttpSocket()); - if(listener.accept(*clients.back()) != fr::Socket::Success) - { - clients.pop_back(); - continue; - } - - //Add them to the selector if connected successfully - selector.add(*clients.back()); - std::cout << "Got new connection from: " << clients.back()->get_remote_address() << std::endl; - } - else - { - //Else it's one of the clients who's sent some data. Check each one - for(auto iter = clients.begin(); iter != clients.end();) - { - if(selector.is_ready(**iter)) - - { - //This client has sent a HTTP request, so receive_request it - fr::HttpRequest request; - if((*iter)->receive(request) == fr::Socket::Success) - { - //Print to the console what we've been requested for - std::cout << "Requested: " << request.get_uri() << std::endl; - - //Construct a response - fr::HttpResponse response; - response.set_body("

Hello, World!

"); - - //Send the response, and close the connection - (*iter)->send(response); - (*iter)->close(); - } - else - { - std::cout << (*iter)->get_remote_address() << " has disconnected." << std::endl; - selector.remove(*iter->get()); - iter = clients.erase(iter); - } - } - else - { - iter++; - } - } - } - } -} - -void client() -{ - fr::HttpSocket socket; - if(socket.connect("127.0.0.1", "8081") != fr::Socket::Success) - { - std::cout << "Failed to connect to web server!" << std::endl; - return; - } - - socket.set_blocking(false); - socket.set_blocking(true); - fr::HttpResponse response; - fr::Socket::Status status = socket.receive(response); - if(status != fr::Socket::Success) - { - if(status == fr::Socket::WouldBlock) - std::cout << "WouldBlock" << std::endl; - std::cout << "Failed to receive HTTP response from the server!" << std::endl; - } - - std::cout << "Got page body: " << response.get_body() << std::endl; - return; - -} - int main() { - std::thread t1(&server); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - auto start = std::chrono::system_clock::now(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - client(); - - t1.join(); - return 0;