diff --git a/include/frnetlib/SSLListener.h b/include/frnetlib/SSLListener.h index 7c5c47e..c168813 100644 --- a/include/frnetlib/SSLListener.h +++ b/include/frnetlib/SSLListener.h @@ -24,7 +24,7 @@ namespace fr class SSLListener : public Listener { public: - SSLListener(std::shared_ptr ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept; + explicit SSLListener(std::shared_ptr ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept; virtual ~SSLListener() noexcept; SSLListener(SSLListener &&o) noexcept = default; @@ -47,7 +47,7 @@ namespace fr /*! * Closes the socket */ - virtual void close_socket() override; + void close_socket() override; /*! * Calls the shutdown syscall on the socket. @@ -57,21 +57,21 @@ namespace fr * it to immediately return (you might want to do this if * you're exiting and need the blocking socket to return). */ - virtual void shutdown() override; + void shutdown() override; /*! * Gets the socket descriptor. * * @return The listen socket descriptor */ - virtual int32_t get_socket_descriptor() const override; + int32_t get_socket_descriptor() const override; /*! * Sets the socket descriptor. * * @param descriptor The listen descriptor to use */ - virtual void set_socket_descriptor(int32_t descriptor) override; + void set_socket_descriptor(int32_t descriptor) override; private: mbedtls_net_context listen_fd; diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 53c2da0..a184cf6 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -4,7 +4,6 @@ #ifndef FRNETLIB_SSL_SOCKET_H #define FRNETLIB_SSL_SOCKET_H - #ifdef SSL_ENABLED #include "TcpSocket.h" @@ -22,7 +21,7 @@ namespace fr class SSLSocket : public Socket { public: - SSLSocket(std::shared_ptr ssl_context) noexcept; + explicit SSLSocket(std::shared_ptr ssl_context) noexcept; virtual ~SSLSocket() noexcept; @@ -83,7 +82,7 @@ namespace fr * * @return The socket's descriptor. */ - virtual int32_t get_socket_descriptor() const override + int32_t get_socket_descriptor() const override { return ssl_socket_descriptor->fd; } @@ -93,7 +92,7 @@ namespace fr * * @param should_block True to block, false otherwise. */ - virtual void set_blocking(bool should_block) override + void set_blocking(bool should_block) override { abort(); } @@ -103,7 +102,7 @@ namespace fr * * @return True if it's connected. False otherwise. */ - inline virtual bool connected() const override final + inline bool connected() const final { return ssl_socket_descriptor->fd > -1; } diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index 6503d62..e6db304 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -3,6 +3,7 @@ // #include +#include #include #include "frnetlib/SSLListener.h" #ifdef SSL_ENABLED @@ -12,7 +13,7 @@ namespace fr { SSLListener::SSLListener(std::shared_ptr ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept - : ssl_context(ssl_context_) + : ssl_context(std::move(ssl_context_)) { //Initialise SSL objects required listen_fd.fd = -1; @@ -53,7 +54,7 @@ namespace fr //Apply them mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg); - mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL); + mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, nullptr); if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0) { @@ -68,6 +69,7 @@ namespace fr mbedtls_x509_crt_free(&srvcert); mbedtls_pk_free(&pkey); mbedtls_ssl_config_free(&conf); + mbedtls_net_free(&listen_fd); } Socket::Status fr::SSLListener::listen(const std::string &port) @@ -90,42 +92,46 @@ namespace fr Socket::Status SSLListener::accept(Socket &client_) { //Cast to SSLSocket. Will throw bad cast on failure. - SSLSocket &client = dynamic_cast(client_); + auto &client = dynamic_cast(client_); //Initialise mbedtls int error = 0; - mbedtls_ssl_context *ssl = new mbedtls_ssl_context; - client.set_ssl_context(std::unique_ptr(ssl)); + std::unique_ptr ssl(new mbedtls_ssl_context); + std::unique_ptr client_fd(new mbedtls_net_context); - mbedtls_ssl_init(ssl); - if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0) + mbedtls_ssl_init(ssl.get()); + mbedtls_net_init(client_fd.get()); + auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());}; + if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) { std::cout << "Failed to apply SSL setings: " << error << std::endl; + free_contexts(); return Socket::Error; } //Accept a connection - mbedtls_net_context *client_fd = new mbedtls_net_context; - client.set_net_context(std::unique_ptr(client_fd)); - mbedtls_net_init(client_fd); - - if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0) + if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), nullptr, 0, nullptr)) != 0) { std::cout << "Accept error: " << error << std::endl; + free_contexts(); return Socket::Error; } - mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); //SSL Handshake - while((error = mbedtls_ssl_handshake(ssl)) != 0) + while((error = mbedtls_ssl_handshake(ssl.get())) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { std::cout << "Handshake error: " << error << std::endl; + free_contexts(); return Socket::Status::HandshakeFailed; } } + + client.set_ssl_context(std::move(ssl)); + client.set_net_context(std::move(client_fd)); return Socket::Success; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 23d6d66..afef88c 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -4,6 +4,7 @@ #include "frnetlib/SSLSocket.h" #include +#include #ifdef SSL_ENABLED @@ -12,7 +13,7 @@ namespace fr { SSLSocket::SSLSocket(std::shared_ptr ssl_context_) noexcept - : ssl_context(ssl_context_) + : ssl_context(std::move(ssl_context_)) { //Initialise mbedtls structures mbedtls_ssl_config_init(&conf); @@ -30,13 +31,14 @@ namespace fr void SSLSocket::close_socket() { - if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1) + if(ssl_socket_descriptor) + mbedtls_net_free(ssl_socket_descriptor.get()); + if(ssl) { - if(ssl) - mbedtls_ssl_close_notify(ssl.get()); - if(ssl_socket_descriptor) - mbedtls_net_free(ssl_socket_descriptor.get()); + mbedtls_ssl_close_notify(ssl.get()); + mbedtls_ssl_free(ssl.get()); } + } Socket::Status SSLSocket::send_raw(const char *data, size_t size) @@ -84,8 +86,8 @@ namespace fr Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) { //Initialise mbedtls stuff - ssl = std::unique_ptr(new mbedtls_ssl_context); - ssl_socket_descriptor = std::unique_ptr(new mbedtls_net_context); + ssl = std::make_unique(); + ssl_socket_descriptor = std::make_unique(); mbedtls_ssl_init(ssl.get()); mbedtls_net_init(ssl_socket_descriptor.get()); @@ -103,7 +105,7 @@ namespace fr } mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED); - mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, NULL); + mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, nullptr); mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg); if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0) @@ -116,7 +118,7 @@ namespace fr return Socket::Status::Error; } - mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); //Do SSL handshake while((error = mbedtls_ssl_handshake(ssl.get())) != 0)