SSL Contexts are now shared between sockets

This commit is contained in:
Cloaked9000 2016-12-16 14:55:41 +00:00
parent 509b37095f
commit 8c94c337e9
13 changed files with 104 additions and 122 deletions

View File

@ -9,7 +9,7 @@ INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR})
include_directories(include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread -lmbedtls -lmbedx509 -lmbedcrypto")
set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h src/SSLSocket.cpp include/SSLSocket.h src/SSLListener.cpp include/SSLListener.h)
set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h src/SSLSocket.cpp include/SSLSocket.h src/SSLListener.cpp include/SSLListener.h include/SSLContext.h)
add_executable(frnetlib ${SOURCE_FILES})
TARGET_LINK_LIBRARIES(frnetlib ${MBEDTLS_LIBRARIES} -lmbedtls -lmbedx509 -lmbedcrypto -static)

View File

@ -139,7 +139,7 @@ namespace fr
*
* @return The status
*/
RequestStatus get_status();
RequestStatus get_status() const;
/*!
* Sets the request URI.

View File

@ -7,6 +7,7 @@
#include "TcpSocket.h"
#include "Http.h"
#include "SSLContext.h"
namespace fr
{
@ -14,6 +15,13 @@ namespace fr
class HttpSocket : public SocketType
{
public:
HttpSocket() noexcept =default;
//Forward constructor arguments to SocketType if needed
template<typename T>
HttpSocket(T &&var)
: SocketType(var){}
/*!
* Receives a HTTP request from the connected socket
*
@ -23,7 +31,7 @@ namespace fr
Socket::Status receive(Http &request)
{
//Create buffer to receive_request the request
std::string buffer(2048, '\0');
std::string buffer(RECV_CHUNK_SIZE, '\0');
//Receive the request
size_t received;

62
include/SSLContext.h Normal file
View File

@ -0,0 +1,62 @@
//
// Created by fred on 16/12/16.
//
#ifndef FRNETLIB_SSLCONTEXT_H
#define FRNETLIB_SSLCONTEXT_H
#define USE_SSL
#ifdef USE_SSL
#include <mbedtls/x509_crt.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <cstring>
#include <iostream>
namespace fr
{
class SSLContext
{
public:
SSLContext(const std::string &ca_certs_path)
{
int error = 0;
//Initialise mbed_tls structures
mbedtls_x509_crt_init(&cacert);
mbedtls_ctr_drbg_init(&ctr_drbg);
//Seed random number generator
mbedtls_entropy_init(&entropy);
if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, nullptr, 0)) != 0)
{
std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl;
return;
}
//Load root CA certificate
if((error = mbedtls_x509_crt_parse_file(&cacert, ca_certs_path.c_str()) < 0))
{
std::cout << "Failed to parse root CA certificates. Parse returned: " << error << std::endl;
return;
}
}
~SSLContext()
{
mbedtls_ctr_drbg_free(&ctr_drbg);
mbedtls_entropy_free(&entropy);
mbedtls_x509_crt_free(&cacert);
}
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_x509_crt cacert;
};
}
#endif // USE_SSSL
#endif //FRNETLIB_SSLCONTEXT_H

View File

@ -26,7 +26,7 @@ namespace fr
class SSLListener : public Socket
{
public:
SSLListener(const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
virtual ~SSLListener() noexcept;
SSLListener(SSLListener &&o) noexcept = default;
@ -60,12 +60,12 @@ namespace fr
private:
mbedtls_net_context listen_fd;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_config conf;
mbedtls_x509_crt srvcert;
mbedtls_pk_context pkey;
std::shared_ptr<SSLContext> ssl_context;
//Stubs
virtual void close(){}
virtual Socket::Status connect(const std::string &address, const std::string &port){return Socket::Error;}

View File

@ -10,6 +10,7 @@
#ifdef SSL_ENABLED
#include "TcpSocket.h"
#include "SSLContext.h"
#include <mbedtls/net_sockets.h>
#include <mbedtls/debug.h>
#include <mbedtls/ssl.h>
@ -69,7 +70,7 @@ namespace fr
class SSLSocket : public Socket
{
public:
SSLSocket() noexcept;
SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
~SSLSocket() noexcept;
@ -148,13 +149,11 @@ namespace fr
private:
std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer;
std::shared_ptr<SSLContext> ssl_context;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
std::unique_ptr<mbedtls_ssl_context> ssl;
mbedtls_ssl_config conf;
mbedtls_x509_crt cacert;
uint32_t flags;
};
}

View File

@ -10,7 +10,7 @@
namespace fr
{
#define RECV_CHUNK_SIZE 1024 //How much data to try and recv at once
#define RECV_CHUNK_SIZE 2048 //How much data to try and recv at once
class TcpSocket : public Socket
{

View File

@ -11,73 +11,19 @@
int main()
{
//Bind to port
fr::SSLListener listener("key.crt", "key.pem", "private.key");
if(listener.listen("8080") != fr::Socket::Success)
{
//Error
}
std::shared_ptr<fr::SSLContext> ssl_context(new fr::SSLContext("certs.crt"));
//Create socket selector and add listener
fr::SocketSelector selector;
selector.add(listener);
fr::HttpSocket<fr::SSLSocket> socket(ssl_context);
std::string addr;
std::cin >> addr;
socket.connect(addr, "443");
//Create vector to store open connections
std::vector<std::unique_ptr<fr::Socket>> connections;
//Infinitely loop. No timeout is specified so it will not return false.
while(selector.wait())
{
//Check if it was the selector who sent data
if(selector.is_ready(listener))
{
std::unique_ptr<fr::HttpSocket<fr::SSLSocket>> socket(new fr::HttpSocket<fr::SSLSocket>);
if(listener.accept(*socket) == fr::Socket::Success)
{
selector.add(*socket);
connections.emplace_back(std::move(socket));
}
}
//Else it must have been one of the clients
else
{
//Find which client send the data
for(auto iter = connections.begin(); iter != connections.end();)
{
//Eww
fr::HttpSocket<fr::SSLSocket> &client = (fr::HttpSocket<fr::SSLSocket>&)**iter;
//Check if it's this client
if(selector.is_ready(client))
{
//It is, so receive their HTTP request
fr::HttpRequest request;
if(client.receive(request) == fr::Socket::Success)
{
//Send back a HTTP response containing 'Hello, World!'
fr::HttpResponse response;
response.set_body("<h1>frnetlib test page</h1>");
client.send(response);
//Remove them from the selector and close the connection
selector.remove(client);
client.close();
iter = connections.erase(iter);
}
else
{
iter++;
}
}
else
{
iter++;
}
}
}
}
fr::HttpRequest request;
socket.send(request);
fr::HttpResponse response;
socket.receive(response);
std::cout << response.get_body() << std::endl;
return 0;
}

View File

@ -54,7 +54,7 @@ namespace fr
get_variables.clear();
uri = "/";
status = Ok;
request_type = Get;
request_type = Unknown;
}
std::string &Http::get(const std::string &key)
@ -87,7 +87,7 @@ namespace fr
status = status_;
}
Http::RequestStatus Http::get_status()
Http::RequestStatus Http::get_status() const
{
return status;
}

View File

@ -108,7 +108,7 @@ namespace fr
std::string HttpRequest::construct(const std::string &host) const
{
//Add HTTP header
std::string request = request_type_to_string(request_type) + " " + uri + " HTTP/1.1\n";
std::string request = request_type_to_string(request_type == Http::Unknown ? Http::Get : request_type) + " " + uri + " HTTP/1.1\n";
//Add the headers to the request
for(const auto &header : headers)

View File

@ -9,7 +9,6 @@ namespace fr
{
void HttpResponse::parse(const std::string &response_data)
{
std::cout << "Parsing: " << response_data << std::endl;
//Clear old headers/data
clear();

View File

@ -2,24 +2,24 @@
// Created by fred on 13/12/16.
//
#include <chrono>
#include "SSLListener.h"
#ifdef SSL_ENABLED
namespace fr
{
SSLListener::SSLListener(const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
: ssl_context(ssl_context_)
{
//Initialise SSL objects required
mbedtls_net_init(&listen_fd);
mbedtls_ssl_config_init(&conf);
mbedtls_x509_crt_init(&srvcert);
mbedtls_pk_init(&pkey);
mbedtls_entropy_init(&entropy);
mbedtls_ctr_drbg_init(&ctr_drbg);
int error = 0;
//Load certificates and private key todo: Switch from inbuilt test certificates
//Load certificates and private key
error = mbedtls_x509_crt_parse_file(&srvcert, crt_path.c_str());
if(error != 0)
{
@ -41,14 +41,6 @@ namespace fr
return;
}
//Seed random number generator
if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, 0)) != 0)
{
std::cout << "Failed to initialise SSL listener. Failed to seed random number generator: " << error
<< std::endl;
return;
}
//Setup data structures
if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_SERVER, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
{
@ -57,7 +49,7 @@ namespace fr
}
//Apply them
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL);
if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
@ -73,8 +65,6 @@ namespace fr
mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey);
mbedtls_ssl_config_free(&conf);
mbedtls_ctr_drbg_free(&ctr_drbg);
mbedtls_entropy_free( &entropy);
}
Socket::Status fr::SSLListener::listen(const std::string &port)
@ -98,7 +88,6 @@ namespace fr
return Socket::Error;
}
//Accept a connection
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
mbedtls_net_init(client_fd.get());
@ -111,12 +100,12 @@ namespace fr
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL);
auto start = std::chrono::system_clock::now();
//SSL Handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{
std::cout << "Handshake failed: " << error << std::endl;
return Socket::Status::HandshakeFailed;
}
}

View File

@ -8,31 +8,13 @@
namespace fr
{
SSLSocket::SSLSocket() noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE])
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE]),
ssl_context(ssl_context_)
{
int error = 0;
const char *pers = "ssl_client1";
//Initialise mbedtls structures
mbedtls_ssl_config_init(&conf);
mbedtls_x509_crt_init(&cacert);
mbedtls_ctr_drbg_init(&ctr_drbg);
//Seed random number generator
mbedtls_entropy_init(&entropy);
if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, (const unsigned char *)pers, strlen(pers))) != 0)
{
std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl;
return;
}
//Load root CA certificate
if((error = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)certs.c_str(), certs.size() + 1) < 0))
{
std::cout << "Failed to parse root CA certificate. Parse returned: " << error << std::endl;
return;
}
}
SSLSocket::~SSLSocket() noexcept
@ -41,10 +23,7 @@ namespace fr
close();
//Cleanup mbedsql stuff
mbedtls_x509_crt_free(&cacert);
mbedtls_ssl_config_free(&conf);
mbedtls_ctr_drbg_free(&ctr_drbg);
mbedtls_entropy_free(&entropy);
}
void SSLSocket::close()
@ -134,8 +113,8 @@ namespace fr
}
mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
mbedtls_ssl_conf_ca_chain(&conf, &cacert, NULL);
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ctr_drbg);
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, NULL);
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
{