Fixed memory leak

fr::SSLSocket did not call mbedtls_net_free. It called mbedtls_ssl_close_notify instead.
This commit is contained in:
Fred Nicolson 2017-07-31 11:58:22 +01:00
parent 84382cad0b
commit dff81f495e
4 changed files with 41 additions and 34 deletions

View File

@ -24,7 +24,7 @@ namespace fr
class SSLListener : public Listener class SSLListener : public Listener
{ {
public: public:
SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept; explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
virtual ~SSLListener() noexcept; virtual ~SSLListener() noexcept;
SSLListener(SSLListener &&o) noexcept = default; SSLListener(SSLListener &&o) noexcept = default;
@ -47,7 +47,7 @@ namespace fr
/*! /*!
* Closes the socket * Closes the socket
*/ */
virtual void close_socket() override; void close_socket() override;
/*! /*!
* Calls the shutdown syscall on the socket. * Calls the shutdown syscall on the socket.
@ -57,21 +57,21 @@ namespace fr
* it to immediately return (you might want to do this if * it to immediately return (you might want to do this if
* you're exiting and need the blocking socket to return). * you're exiting and need the blocking socket to return).
*/ */
virtual void shutdown() override; void shutdown() override;
/*! /*!
* Gets the socket descriptor. * Gets the socket descriptor.
* *
* @return The listen socket descriptor * @return The listen socket descriptor
*/ */
virtual int32_t get_socket_descriptor() const override; int32_t get_socket_descriptor() const override;
/*! /*!
* Sets the socket descriptor. * Sets the socket descriptor.
* *
* @param descriptor The listen descriptor to use * @param descriptor The listen descriptor to use
*/ */
virtual void set_socket_descriptor(int32_t descriptor) override; void set_socket_descriptor(int32_t descriptor) override;
private: private:
mbedtls_net_context listen_fd; mbedtls_net_context listen_fd;

View File

@ -4,7 +4,6 @@
#ifndef FRNETLIB_SSL_SOCKET_H #ifndef FRNETLIB_SSL_SOCKET_H
#define FRNETLIB_SSL_SOCKET_H #define FRNETLIB_SSL_SOCKET_H
#ifdef SSL_ENABLED #ifdef SSL_ENABLED
#include "TcpSocket.h" #include "TcpSocket.h"
@ -22,7 +21,7 @@ namespace fr
class SSLSocket : public Socket class SSLSocket : public Socket
{ {
public: public:
SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept; explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
virtual ~SSLSocket() noexcept; virtual ~SSLSocket() noexcept;
@ -83,7 +82,7 @@ namespace fr
* *
* @return The socket's descriptor. * @return The socket's descriptor.
*/ */
virtual int32_t get_socket_descriptor() const override int32_t get_socket_descriptor() const override
{ {
return ssl_socket_descriptor->fd; return ssl_socket_descriptor->fd;
} }
@ -93,7 +92,7 @@ namespace fr
* *
* @param should_block True to block, false otherwise. * @param should_block True to block, false otherwise.
*/ */
virtual void set_blocking(bool should_block) override void set_blocking(bool should_block) override
{ {
abort(); abort();
} }
@ -103,7 +102,7 @@ namespace fr
* *
* @return True if it's connected. False otherwise. * @return True if it's connected. False otherwise.
*/ */
inline virtual bool connected() const override final inline bool connected() const final
{ {
return ssl_socket_descriptor->fd > -1; return ssl_socket_descriptor->fd > -1;
} }

View File

@ -3,6 +3,7 @@
// //
#include <chrono> #include <chrono>
#include <utility>
#include <frnetlib/TcpListener.h> #include <frnetlib/TcpListener.h>
#include "frnetlib/SSLListener.h" #include "frnetlib/SSLListener.h"
#ifdef SSL_ENABLED #ifdef SSL_ENABLED
@ -12,7 +13,7 @@
namespace fr namespace fr
{ {
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
: ssl_context(ssl_context_) : ssl_context(std::move(ssl_context_))
{ {
//Initialise SSL objects required //Initialise SSL objects required
listen_fd.fd = -1; listen_fd.fd = -1;
@ -53,7 +54,7 @@ namespace fr
//Apply them //Apply them
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg); mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL); mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, nullptr);
if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0) if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
{ {
@ -68,6 +69,7 @@ namespace fr
mbedtls_x509_crt_free(&srvcert); mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey); mbedtls_pk_free(&pkey);
mbedtls_ssl_config_free(&conf); mbedtls_ssl_config_free(&conf);
mbedtls_net_free(&listen_fd);
} }
Socket::Status fr::SSLListener::listen(const std::string &port) Socket::Status fr::SSLListener::listen(const std::string &port)
@ -90,42 +92,46 @@ namespace fr
Socket::Status SSLListener::accept(Socket &client_) Socket::Status SSLListener::accept(Socket &client_)
{ {
//Cast to SSLSocket. Will throw bad cast on failure. //Cast to SSLSocket. Will throw bad cast on failure.
SSLSocket &client = dynamic_cast<SSLSocket&>(client_); auto &client = dynamic_cast<SSLSocket&>(client_);
//Initialise mbedtls //Initialise mbedtls
int error = 0; int error = 0;
mbedtls_ssl_context *ssl = new mbedtls_ssl_context; std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl)); std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
mbedtls_ssl_init(ssl); mbedtls_ssl_init(ssl.get());
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0) mbedtls_net_init(client_fd.get());
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
{ {
std::cout << "Failed to apply SSL setings: " << error << std::endl; std::cout << "Failed to apply SSL setings: " << error << std::endl;
free_contexts();
return Socket::Error; return Socket::Error;
} }
//Accept a connection //Accept a connection
mbedtls_net_context *client_fd = new mbedtls_net_context; if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), nullptr, 0, nullptr)) != 0)
client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
mbedtls_net_init(client_fd);
if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
{ {
std::cout << "Accept error: " << error << std::endl; std::cout << "Accept error: " << error << std::endl;
free_contexts();
return Socket::Error; return Socket::Error;
} }
mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake //SSL Handshake
while((error = mbedtls_ssl_handshake(ssl)) != 0) while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{ {
std::cout << "Handshake error: " << error << std::endl; std::cout << "Handshake error: " << error << std::endl;
free_contexts();
return Socket::Status::HandshakeFailed; return Socket::Status::HandshakeFailed;
} }
} }
client.set_ssl_context(std::move(ssl));
client.set_net_context(std::move(client_fd));
return Socket::Success; return Socket::Success;
} }

View File

@ -4,6 +4,7 @@
#include "frnetlib/SSLSocket.h" #include "frnetlib/SSLSocket.h"
#include <memory> #include <memory>
#include <utility>
#ifdef SSL_ENABLED #ifdef SSL_ENABLED
@ -12,7 +13,7 @@
namespace fr namespace fr
{ {
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: ssl_context(ssl_context_) : ssl_context(std::move(ssl_context_))
{ {
//Initialise mbedtls structures //Initialise mbedtls structures
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
@ -30,13 +31,14 @@ namespace fr
void SSLSocket::close_socket() void SSLSocket::close_socket()
{ {
if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
{
if(ssl)
mbedtls_ssl_close_notify(ssl.get());
if(ssl_socket_descriptor) if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get()); mbedtls_net_free(ssl_socket_descriptor.get());
if(ssl)
{
mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get());
} }
} }
Socket::Status SSLSocket::send_raw(const char *data, size_t size) Socket::Status SSLSocket::send_raw(const char *data, size_t size)
@ -84,8 +86,8 @@ namespace fr
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
{ {
//Initialise mbedtls stuff //Initialise mbedtls stuff
ssl = std::unique_ptr<mbedtls_ssl_context>(new mbedtls_ssl_context); ssl = std::make_unique<mbedtls_ssl_context>();
ssl_socket_descriptor = std::unique_ptr<mbedtls_net_context>(new mbedtls_net_context); ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get()); mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get()); mbedtls_net_init(ssl_socket_descriptor.get());
@ -103,7 +105,7 @@ namespace fr
} }
mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED); mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, NULL); mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, nullptr);
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg); mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0) if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
@ -116,7 +118,7 @@ namespace fr
return Socket::Status::Error; return Socket::Status::Error;
} }
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, NULL); mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//Do SSL handshake //Do SSL handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0) while((error = mbedtls_ssl_handshake(ssl.get())) != 0)