Improving build system

Instead of #ifdefing files out, they are no longer included by CMake instead.
This commit is contained in:
Unknown 2018-02-24 21:14:43 +00:00
parent ff25d11089
commit 0840c07e24
9 changed files with 62 additions and 74 deletions

2
.gitignore vendored
View File

@ -1,3 +1,5 @@
.idea .idea
cmake-build-debug cmake-build-debug
cmake-build-release cmake-build-release
CMakeCache\.txt

View File

@ -24,7 +24,7 @@ add_definitions(-DLISTEN_QUEUE_SIZE=${LISTEN_QUEUE_SIZE})
if(USE_SSL) if(USE_SSL)
FIND_PACKAGE(MBEDTLS) FIND_PACKAGE(MBEDTLS)
INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR})
add_definitions(-DSSL_ENABLED) set(SOURCE_FILES ${SOURCE_FILES} src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h)
endif() endif()
add_definitions(-DNOMINMAX) add_definitions(-DNOMINMAX)
@ -37,7 +37,7 @@ add_definitions(-Dntodh)
set( INCLUDE_PATH "${PROJECT_SOURCE_DIR}/include" ) set( INCLUDE_PATH "${PROJECT_SOURCE_DIR}/include" )
set( SOURCE_PATH "${PROJECT_SOURCE_DIR}/src" ) set( SOURCE_PATH "${PROJECT_SOURCE_DIR}/src" )
set(SOURCE_FILES main.cpp src/TcpSocket.cpp include/frnetlib/TcpSocket.h src/TcpListener.cpp include/frnetlib/TcpListener.h src/Socket.cpp include/frnetlib/Socket.h src/Packet.cpp include/frnetlib/Packet.h include/frnetlib/NetworkEncoding.h src/SocketSelector.cpp include/frnetlib/SocketSelector.h src/HttpRequest.cpp include/frnetlib/HttpRequest.h src/HttpResponse.cpp include/frnetlib/HttpResponse.h src/Http.cpp include/frnetlib/Http.h src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h src/SocketReactor.cpp include/frnetlib/SocketReactor.h include/frnetlib/Packetable.h include/frnetlib/Listener.h src/URL.cpp include/frnetlib/URL.h include/frnetlib/Sendable.h) set(SOURCE_FILES ${SOURCE_FILES} main.cpp src/TcpSocket.cpp include/frnetlib/TcpSocket.h src/TcpListener.cpp include/frnetlib/TcpListener.h src/Socket.cpp include/frnetlib/Socket.h src/Packet.cpp include/frnetlib/Packet.h include/frnetlib/NetworkEncoding.h src/SocketSelector.cpp include/frnetlib/SocketSelector.h src/HttpRequest.cpp include/frnetlib/HttpRequest.h src/HttpResponse.cpp include/frnetlib/HttpResponse.h src/Http.cpp include/frnetlib/Http.h src/SocketReactor.cpp include/frnetlib/SocketReactor.h include/frnetlib/Packetable.h include/frnetlib/Listener.h src/URL.cpp include/frnetlib/URL.h include/frnetlib/Sendable.h)
include_directories(include) include_directories(include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")

View File

@ -5,8 +5,6 @@
#ifndef FRNETLIB_SSLCONTEXT_H #ifndef FRNETLIB_SSLCONTEXT_H
#define FRNETLIB_SSLCONTEXT_H #define FRNETLIB_SSLCONTEXT_H
#ifdef SSL_ENABLED
#include <mbedtls/x509_crt.h> #include <mbedtls/x509_crt.h>
#include <mbedtls/ctr_drbg.h> #include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h> #include <mbedtls/entropy.h>
@ -86,7 +84,4 @@ namespace fr
}; };
} }
#endif // SSL_ENABLED
#endif //FRNETLIB_SSLCONTEXT_H #endif //FRNETLIB_SSLCONTEXT_H

View File

@ -5,8 +5,6 @@
#ifndef FRNETLIB_SSLLISTENER_H #ifndef FRNETLIB_SSLLISTENER_H
#define FRNETLIB_SSLLISTENER_H #define FRNETLIB_SSLLISTENER_H
#ifdef SSL_ENABLED
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
#include <mbedtls/debug.h> #include <mbedtls/debug.h>
#include <mbedtls/ssl.h> #include <mbedtls/ssl.h>
@ -83,6 +81,4 @@ namespace fr
}; };
} }
#endif //SLL_ENABLED
#endif //FRNETLIB_SSLLISTENER_H #endif //FRNETLIB_SSLLISTENER_H

View File

@ -4,8 +4,6 @@
#ifndef FRNETLIB_SSL_SOCKET_H #ifndef FRNETLIB_SSL_SOCKET_H
#define FRNETLIB_SSL_SOCKET_H #define FRNETLIB_SSL_SOCKET_H
#ifdef SSL_ENABLED
#include "TcpSocket.h" #include "TcpSocket.h"
#include "SSLContext.h" #include "SSLContext.h"
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
@ -64,6 +62,13 @@ namespace fr
*/ */
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override; virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Sets the socket file descriptor.
*
* @param descriptor The socket descriptor.
*/
virtual void set_descriptor(int descriptor) override;
/*! /*!
* Set the SSL context * Set the SSL context
* *
@ -71,13 +76,6 @@ namespace fr
*/ */
void set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context); void set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context);
/*!
* Set the NET context
*
* @param context The NET context to use
*/
void set_net_context(std::unique_ptr<mbedtls_net_context> context);
/*! /*!
* Gets the underlying socket descriptor. * Gets the underlying socket descriptor.
* *
@ -85,9 +83,7 @@ namespace fr
*/ */
int32_t get_socket_descriptor() const override int32_t get_socket_descriptor() const override
{ {
if(!ssl_socket_descriptor) return ssl_socket_descriptor.fd;
return -1;
return ssl_socket_descriptor->fd;
} }
/*! /*!
@ -107,7 +103,7 @@ namespace fr
*/ */
inline bool connected() const final inline bool connected() const final
{ {
return ssl_socket_descriptor && ssl_socket_descriptor->fd > -1; return ssl_socket_descriptor.fd > -1;
} }
/*! /*!
@ -123,7 +119,7 @@ namespace fr
private: private:
std::shared_ptr<SSLContext> ssl_context; std::shared_ptr<SSLContext> ssl_context;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor; mbedtls_net_context ssl_socket_descriptor;
std::unique_ptr<mbedtls_ssl_context> ssl; std::unique_ptr<mbedtls_ssl_context> ssl;
mbedtls_ssl_config conf; mbedtls_ssl_config conf;
uint32_t flags; uint32_t flags;
@ -131,6 +127,4 @@ namespace fr
}; };
} }
#endif
#endif //FRNETLIB_SSLSOCKET_H #endif //FRNETLIB_SSLSOCKET_H

View File

@ -68,25 +68,6 @@ namespace fr
*/ */
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0; virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0;
/*!
* Gets the socket's printable remote address
*
* @return The string address
*/
inline const std::string &get_remote_address()
{
return remote_address;
}
/*!
* Sets the connections remote address.
*
* @param addr The remote address to use
*/
void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
/*! /*!
* Sets the socket to blocking or non-blocking. * Sets the socket to blocking or non-blocking.
@ -106,7 +87,6 @@ namespace fr
*/ */
virtual Status send_raw(const char *data, size_t size) = 0; virtual Status send_raw(const char *data, size_t size) = 0;
/*! /*!
* Receives raw data from the socket, without any of * Receives raw data from the socket, without any of
* frnetlib's framing. Useful for communicating through * frnetlib's framing. Useful for communicating through
@ -135,6 +115,33 @@ namespace fr
*/ */
virtual bool connected() const =0; virtual bool connected() const =0;
/*!
* Sets the socket file descriptor.
*
* @param descriptor The socket descriptor.
*/
virtual void set_descriptor(int descriptor)=0;
/*!
* Gets the socket's printable remote address
*
* @return The string address
*/
inline const std::string &get_remote_address()
{
return remote_address;
}
/*!
* Sets the connections remote address.
*
* @param addr The remote address to use
*/
void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
/*! /*!
* Send a Sendable object through the socket * Send a Sendable object through the socket
* *

View File

@ -40,7 +40,7 @@ public:
* *
* @param descriptor The socket descriptor. * @param descriptor The socket descriptor.
*/ */
virtual void set_descriptor(int descriptor); virtual void set_descriptor(int descriptor) override;
/*! /*!
* Attempts to send raw data down the socket, without * Attempts to send raw data down the socket, without

View File

@ -7,9 +7,9 @@
#include "frnetlib/NetworkEncoding.h" #include "frnetlib/NetworkEncoding.h"
#include "frnetlib/TcpListener.h" #include "frnetlib/TcpListener.h"
#include "frnetlib/SSLListener.h" #include "frnetlib/SSLListener.h"
#ifdef SSL_ENABLED
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
#include <iostream>
namespace fr namespace fr
{ {
@ -86,16 +86,16 @@ namespace fr
Socket::Status SSLListener::accept(Socket &client_) Socket::Status SSLListener::accept(Socket &client_)
{ {
//Cast to SSLSocket. Will throw bad cast on failure. //Cast to SSLSocket. Will throw bad cast on failure.
auto &client = dynamic_cast<SSLSocket&>(client_); SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
//Initialise mbedtls //Initialise mbedtls
int error = 0; int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context); std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context); mbedtls_net_context client_fd;
mbedtls_ssl_init(ssl.get()); mbedtls_ssl_init(ssl.get());
mbedtls_net_init(client_fd.get()); mbedtls_net_init(&client_fd);
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());}; auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(&client_fd);};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
{ {
std::cout << "Failed to apply SSL setings: " << error << std::endl; std::cout << "Failed to apply SSL setings: " << error << std::endl;
@ -106,14 +106,14 @@ namespace fr
//Accept a connection //Accept a connection
char client_ip[INET6_ADDRSTRLEN] = {0}; char client_ip[INET6_ADDRSTRLEN] = {0};
size_t ip_len = 0; size_t ip_len = 0;
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), client_ip, sizeof(client_ip), &ip_len)) != 0) if((error = mbedtls_net_accept(&listen_fd, &client_fd, client_ip, sizeof(client_ip), &ip_len)) != 0)
{ {
std::cout << "Accept error: " << error << std::endl; std::cout << "Accept error: " << error << std::endl;
free_contexts(); free_contexts();
return Socket::Error; return Socket::Error;
} }
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); mbedtls_ssl_set_bio(ssl.get(), &client_fd, mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake //SSL Handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0) while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
@ -132,14 +132,14 @@ namespace fr
char client_printable_addr[INET6_ADDRSTRLEN]; char client_printable_addr[INET6_ADDRSTRLEN];
struct sockaddr_storage socket_address{}; struct sockaddr_storage socket_address{};
socklen_t socket_length; socklen_t socket_length;
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length); error = getpeername(client_fd.fd, (struct sockaddr*)&socket_address, &socket_length);
if(error == 0) if(error == 0)
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST); error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
if(error != 0) if(error != 0)
strcpy(client_printable_addr, "unknown"); strcpy(client_printable_addr, "unknown");
client.set_ssl_context(std::move(ssl)); client.set_ssl_context(std::move(ssl));
client.set_net_context(std::move(client_fd)); client.set_descriptor(client_fd.fd);
client.set_remote_address(client_printable_addr); client.set_remote_address(client_printable_addr);
return Socket::Success; return Socket::Success;
} }
@ -168,5 +168,4 @@ namespace fr
} }
} }
} }
#endif

View File

@ -6,8 +6,6 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#ifdef SSL_ENABLED
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
namespace fr namespace fr
@ -18,7 +16,7 @@ namespace fr
{ {
//Initialise mbedtls structures //Initialise mbedtls structures
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
ssl_socket_descriptor.fd = -1;
} }
SSLSocket::~SSLSocket() noexcept SSLSocket::~SSLSocket() noexcept
@ -26,20 +24,20 @@ namespace fr
//Close connection if active //Close connection if active
close_socket(); close_socket();
//Cleanup mbedsql stuff //Cleanup mbedtls stuff
mbedtls_ssl_config_free(&conf); mbedtls_ssl_config_free(&conf);
} }
void SSLSocket::close_socket() void SSLSocket::close_socket()
{ {
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
if(ssl) if(ssl)
{ {
mbedtls_ssl_close_notify(ssl.get()); mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get()); mbedtls_ssl_free(ssl.get());
} }
if(ssl_socket_descriptor.fd > -1)
mbedtls_net_free(&ssl_socket_descriptor);
ssl_socket_descriptor.fd = -1;
} }
Socket::Status SSLSocket::send_raw(const char *data, size_t size) Socket::Status SSLSocket::send_raw(const char *data, size_t size)
@ -88,9 +86,8 @@ namespace fr
{ {
//Initialise mbedtls stuff //Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>(); ssl = std::make_unique<mbedtls_ssl_context>();
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get()); mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get()); mbedtls_net_init(&ssl_socket_descriptor);
//Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to //Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
//Open the descriptor, and then steal it. This is a hack. //Open the descriptor, and then steal it. This is a hack.
@ -99,7 +96,7 @@ namespace fr
auto ret = socket.connect(address, port, timeout); auto ret = socket.connect(address, port, timeout);
if(ret != fr::Socket::Success) if(ret != fr::Socket::Success)
return ret; return ret;
ssl_socket_descriptor->fd = socket.get_socket_descriptor(); ssl_socket_descriptor.fd = socket.get_socket_descriptor();
remote_address = socket.get_remote_address(); remote_address = socket.get_remote_address();
socket.set_descriptor(-1); socket.set_descriptor(-1);
} }
@ -125,7 +122,7 @@ namespace fr
return Socket::Status::Error; return Socket::Status::Error;
} }
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); mbedtls_ssl_set_bio(ssl.get(), &ssl_socket_descriptor, mbedtls_net_send, mbedtls_net_recv, nullptr);
//Do SSL handshake //Do SSL handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0) while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
@ -158,9 +155,9 @@ namespace fr
ssl = std::move(context); ssl = std::move(context);
} }
void SSLSocket::set_net_context(std::unique_ptr<mbedtls_net_context> context) void SSLSocket::set_descriptor(int descriptor)
{ {
ssl_socket_descriptor = std::move(context); ssl_socket_descriptor.fd = descriptor;
reconfigure_socket(); reconfigure_socket();
} }
@ -169,5 +166,3 @@ namespace fr
should_verify = should_verify_; should_verify = should_verify_;
} }
} }
#endif