Fixed memory leak

fr::SSLSocket did not call mbedtls_net_free. It called mbedtls_ssl_close_notify instead.
This commit is contained in:
Fred Nicolson 2017-07-31 11:58:22 +01:00
parent 84382cad0b
commit dff81f495e
4 changed files with 41 additions and 34 deletions

View File

@ -24,7 +24,7 @@ namespace fr
class SSLListener : public Listener
{
public:
SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
virtual ~SSLListener() noexcept;
SSLListener(SSLListener &&o) noexcept = default;
@ -47,7 +47,7 @@ namespace fr
/*!
* Closes the socket
*/
virtual void close_socket() override;
void close_socket() override;
/*!
* Calls the shutdown syscall on the socket.
@ -57,21 +57,21 @@ namespace fr
* it to immediately return (you might want to do this if
* you're exiting and need the blocking socket to return).
*/
virtual void shutdown() override;
void shutdown() override;
/*!
* Gets the socket descriptor.
*
* @return The listen socket descriptor
*/
virtual int32_t get_socket_descriptor() const override;
int32_t get_socket_descriptor() const override;
/*!
* Sets the socket descriptor.
*
* @param descriptor The listen descriptor to use
*/
virtual void set_socket_descriptor(int32_t descriptor) override;
void set_socket_descriptor(int32_t descriptor) override;
private:
mbedtls_net_context listen_fd;

View File

@ -4,7 +4,6 @@
#ifndef FRNETLIB_SSL_SOCKET_H
#define FRNETLIB_SSL_SOCKET_H
#ifdef SSL_ENABLED
#include "TcpSocket.h"
@ -22,7 +21,7 @@ namespace fr
class SSLSocket : public Socket
{
public:
SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
virtual ~SSLSocket() noexcept;
@ -83,7 +82,7 @@ namespace fr
*
* @return The socket's descriptor.
*/
virtual int32_t get_socket_descriptor() const override
int32_t get_socket_descriptor() const override
{
return ssl_socket_descriptor->fd;
}
@ -93,7 +92,7 @@ namespace fr
*
* @param should_block True to block, false otherwise.
*/
virtual void set_blocking(bool should_block) override
void set_blocking(bool should_block) override
{
abort();
}
@ -103,7 +102,7 @@ namespace fr
*
* @return True if it's connected. False otherwise.
*/
inline virtual bool connected() const override final
inline bool connected() const final
{
return ssl_socket_descriptor->fd > -1;
}

View File

@ -3,6 +3,7 @@
//
#include <chrono>
#include <utility>
#include <frnetlib/TcpListener.h>
#include "frnetlib/SSLListener.h"
#ifdef SSL_ENABLED
@ -12,7 +13,7 @@
namespace fr
{
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
: ssl_context(ssl_context_)
: ssl_context(std::move(ssl_context_))
{
//Initialise SSL objects required
listen_fd.fd = -1;
@ -53,7 +54,7 @@ namespace fr
//Apply them
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL);
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, nullptr);
if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
{
@ -68,6 +69,7 @@ namespace fr
mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey);
mbedtls_ssl_config_free(&conf);
mbedtls_net_free(&listen_fd);
}
Socket::Status fr::SSLListener::listen(const std::string &port)
@ -90,42 +92,46 @@ namespace fr
Socket::Status SSLListener::accept(Socket &client_)
{
//Cast to SSLSocket. Will throw bad cast on failure.
SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
auto &client = dynamic_cast<SSLSocket&>(client_);
//Initialise mbedtls
int error = 0;
mbedtls_ssl_context *ssl = new mbedtls_ssl_context;
client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl));
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
mbedtls_ssl_init(ssl);
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0)
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(client_fd.get());
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
{
std::cout << "Failed to apply SSL setings: " << error << std::endl;
free_contexts();
return Socket::Error;
}
//Accept a connection
mbedtls_net_context *client_fd = new mbedtls_net_context;
client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
mbedtls_net_init(client_fd);
if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), nullptr, 0, nullptr)) != 0)
{
std::cout << "Accept error: " << error << std::endl;
free_contexts();
return Socket::Error;
}
mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake
while((error = mbedtls_ssl_handshake(ssl)) != 0)
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{
std::cout << "Handshake error: " << error << std::endl;
free_contexts();
return Socket::Status::HandshakeFailed;
}
}
client.set_ssl_context(std::move(ssl));
client.set_net_context(std::move(client_fd));
return Socket::Success;
}

View File

@ -4,6 +4,7 @@
#include "frnetlib/SSLSocket.h"
#include <memory>
#include <utility>
#ifdef SSL_ENABLED
@ -12,7 +13,7 @@
namespace fr
{
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: ssl_context(ssl_context_)
: ssl_context(std::move(ssl_context_))
{
//Initialise mbedtls structures
mbedtls_ssl_config_init(&conf);
@ -30,13 +31,14 @@ namespace fr
void SSLSocket::close_socket()
{
if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
if(ssl)
{
if(ssl)
mbedtls_ssl_close_notify(ssl.get());
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get());
}
}
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
@ -84,8 +86,8 @@ namespace fr
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
{
//Initialise mbedtls stuff
ssl = std::unique_ptr<mbedtls_ssl_context>(new mbedtls_ssl_context);
ssl_socket_descriptor = std::unique_ptr<mbedtls_net_context>(new mbedtls_net_context);
ssl = std::make_unique<mbedtls_ssl_context>();
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get());
@ -103,7 +105,7 @@ namespace fr
}
mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, NULL);
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, nullptr);
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
@ -116,7 +118,7 @@ namespace fr
return Socket::Status::Error;
}
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, NULL);
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//Do SSL handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)