TLS Test Code

This commit is contained in:
Cloaked9000 2016-12-12 18:03:27 +00:00
parent bbd6ee071b
commit 7e9f007acd
6 changed files with 194 additions and 116 deletions

View File

@ -1,8 +1,15 @@
cmake_minimum_required(VERSION 3.6)
project(frnetlib)
include_directories(include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread")
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} /home/fred/ClionProjects/frnetlib/cmake_modules)
set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h)
add_executable(frnetlib ${SOURCE_FILES})
FIND_PACKAGE(MBEDTLS)
INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR})
include_directories(include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -m64 -fPIC -pthread -lmbedtls -lmbedx509 -lmbedcrypto")
set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/TcpSocket.h src/TcpListener.cpp include/TcpListener.h src/Socket.cpp include/Socket.h src/Packet.cpp include/Packet.h include/NetworkEncoding.h src/SocketSelector.cpp include/SocketSelector.h src/HttpSocket.cpp include/HttpSocket.h src/HttpRequest.cpp include/HttpRequest.h src/HttpResponse.cpp include/HttpResponse.h src/Http.cpp include/Http.h include/TLSSocket.cpp include/TLSSocket.h)
add_executable(frnetlib ${SOURCE_FILES})
TARGET_LINK_LIBRARIES(frnetlib ${MBEDTLS_LIBRARIES} -lmbedtls -lmbedx509 -lmbedcrypto)

View File

@ -0,0 +1,36 @@
find_path(MBEDTLS_INCLUDE_DIR
NAMES mbedtls/ssl.h
PATH_SUFFIXES include
HINTS ${MBEDTLS_ROOT})
find_library(MBEDTLS_LIBRARY
NAMES mbedtls
PATH_SUFFIXES lib
HINTS ${MBEDTLS_ROOT})
find_library(MBEDCRYPTO_LIBRARY
NAMES mbedcrypto
PATH_SUFFIXES lib
HINTS ${MBEDTLS_ROOT})
find_library(MBEDX509_LIBRARY
NAMES mbedx509
PATH_SUFFIXES lib
HINTS ${MBEDTLS_ROOT})
if(MBEDTLS_INCLUDE_DIR AND MBEDTLS_LIBRARY)
set(MBEDTLS_FOUND TRUE)
set(MBEDTLS_LIBRARIES ${MBEDTLS_LIBRARY} ${MBEDCRYPTO_LIBRARY} ${MBEDX509_LIBRARY})
endif()
if(MBEDTLS_FOUND)
if(NOT MBEDTLS_FIND_QUIETLY)
message(STATUS "Found mbed TLS: ${MBEDTLS_LIBRARIES}")
endif()
else()
if(MBEDTLS_FIND_REQUIRED)
message(FATAL_ERROR "mbed TLS was not found")
endif()
endif()
mark_as_advanced(MBEDTLS_INCLUDE_DIR MBEDTLS_LIBRARY)

61
include/TLSSocket.cpp Normal file
View File

@ -0,0 +1,61 @@
//
// Created by fred on 12/12/16.
//
#include "TLSSocket.h"
#ifdef SSL_ENABLED
namespace fr
{
TLSSocket::TLSSocket()
{
int error = 0;
//Initialise mbedtls structures
mbedtls_net_init(&ssl_socket_descriptor);
mbedtls_ssl_init(&ssl);
mbedtls_ssl_config_init(&conf);
mbedtls_x509_crt_init(&cacert);
mbedtls_ctr_drbg_init(&ctr_drbg);
//Seed random number generator
if((error = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, NULL, NULL)) != 0)
{
std::cout << "Failed to initialise random number generator. Returned error: " << error << std::endl;
return;
}
//Load root CA certificate
if((error = mbedtls_x509_crt_parse(&cacert, (const unsigned char *)mbedtls_test_cas_pem, mbedtls_test_cas_pem_len) < 0))
{
std::cout << "Failed to parse root CA certificate. Parse returned: " << error << std::endl;
return;
}
}
Socket::Status TLSSocket::send_raw(const char *data, size_t size)
{
return TcpSocket::send_raw(data, size);
}
Socket::Status TLSSocket::receive_raw(void *data, size_t data_size, size_t &received)
{
return TcpSocket::receive_raw(data, data_size, received);
}
void TLSSocket::set_descriptor(int descriptor)
{
TcpSocket::set_descriptor(descriptor);
}
void TLSSocket::close()
{
TcpSocket::close();
}
Socket::Status TLSSocket::connect(const std::string &address, const std::string &port)
{
return TcpSocket::connect(address, port);
}
}
#endif

83
include/TLSSocket.h Normal file
View File

@ -0,0 +1,83 @@
//
// Created by fred on 12/12/16.
//
#ifndef FRNETLIB_TLSSOCKET_H
#define FRNETLIB_TLSSOCKET_H
#ifdef SSL_ENABLED
#include "TcpSocket.h"
#include <mbedtls/net_sockets.h>
#include <mbedtls/debug.h>
#include <mbedtls/ssl.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/error.h>
#include <mbedtls/certs.h>
namespace fr
{
class TLSSocket : public TcpSocket
{
public:
TLSSocket();
/*!
* Effectively just fr::TcpSocket::send_raw() with encryption
* added in.
*
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
*/
Status send_raw(const char *data, size_t size) override;
/*!
* Effectively just fr::TcpSocket::receive_raw() with encryption
* added in.
*
* @param data Where to store the received data.
* @param data_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc.
*/
Status receive_raw(void *data, size_t data_size, size_t &received) override;
/*!
* Sets the socket file descriptor.
*
* @param descriptor The socket descriptor.
*/
void set_descriptor(int descriptor) override;
/*!
* Close the connection.
*/
void close() override;
/*!
* Connects the socket to an address.
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @return A Socket::Status indicating the status of the operation.
*/
Socket::Status connect(const std::string &address, const std::string &port) override;
private:
mbedtls_net_context ssl_socket_descriptor;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_x509_crt cacert;
};
}
#endif
#endif //FRNETLIB_TLSSOCKET_H

View File

@ -55,7 +55,7 @@ public:
*
* @param descriptor The socket descriptor.
*/
void set_descriptor(int descriptor);
virtual void set_descriptor(int descriptor);
/*!
* Checks to see if we're connected to a socket or not
@ -76,7 +76,7 @@ public:
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
*/
Status send_raw(const char *data, size_t size);
virtual Status send_raw(const char *data, size_t size);
/*!
@ -91,7 +91,7 @@ public:
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc.
*/
Status receive_raw(void *data, size_t data_size, size_t &received);
virtual Status receive_raw(void *data, size_t data_size, size_t &received);
private:
/*!

109
main.cpp
View File

@ -10,117 +10,8 @@
#include "HttpRequest.h"
#include "HttpResponse.h"
void server()
{
//Bind to port
fr::TcpListener listener;
if(listener.listen("8081") != fr::Socket::Success)
{
std::cout << "Failed to listen to port" << std::endl;
return;
}
//Create a selector and a container for holding connected clients
fr::SocketSelector selector;
std::vector<std::unique_ptr<fr::HttpSocket>> clients;
//Add our connection listener to the selector
selector.add(listener);
//Infinitely loop, waiting for connections or data
while(selector.wait())
{
//If the listener is ready, that means we've got a new connection
if(selector.is_ready(listener))
{
//Try and add them to our client container
clients.emplace_back(new fr::HttpSocket());
if(listener.accept(*clients.back()) != fr::Socket::Success)
{
clients.pop_back();
continue;
}
//Add them to the selector if connected successfully
selector.add(*clients.back());
std::cout << "Got new connection from: " << clients.back()->get_remote_address() << std::endl;
}
else
{
//Else it's one of the clients who's sent some data. Check each one
for(auto iter = clients.begin(); iter != clients.end();)
{
if(selector.is_ready(**iter))
{
//This client has sent a HTTP request, so receive_request it
fr::HttpRequest request;
if((*iter)->receive(request) == fr::Socket::Success)
{
//Print to the console what we've been requested for
std::cout << "Requested: " << request.get_uri() << std::endl;
//Construct a response
fr::HttpResponse response;
response.set_body("<h1>Hello, World!</h1>");
//Send the response, and close the connection
(*iter)->send(response);
(*iter)->close();
}
else
{
std::cout << (*iter)->get_remote_address() << " has disconnected." << std::endl;
selector.remove(*iter->get());
iter = clients.erase(iter);
}
}
else
{
iter++;
}
}
}
}
}
void client()
{
fr::HttpSocket socket;
if(socket.connect("127.0.0.1", "8081") != fr::Socket::Success)
{
std::cout << "Failed to connect to web server!" << std::endl;
return;
}
socket.set_blocking(false);
socket.set_blocking(true);
fr::HttpResponse response;
fr::Socket::Status status = socket.receive(response);
if(status != fr::Socket::Success)
{
if(status == fr::Socket::WouldBlock)
std::cout << "WouldBlock" << std::endl;
std::cout << "Failed to receive HTTP response from the server!" << std::endl;
}
std::cout << "Got page body: " << response.get_body() << std::endl;
return;
}
int main()
{
std::thread t1(&server);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
auto start = std::chrono::system_clock::now();
std::this_thread::sleep_for(std::chrono::milliseconds(100));
client();
t1.join();
return 0;