Fixed memory leak
fr::SSLSocket did not call mbedtls_net_free. It called mbedtls_ssl_close_notify instead.
This commit is contained in:
parent
84382cad0b
commit
dff81f495e
@ -24,7 +24,7 @@ namespace fr
|
|||||||
class SSLListener : public Listener
|
class SSLListener : public Listener
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
|
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept;
|
||||||
virtual ~SSLListener() noexcept;
|
virtual ~SSLListener() noexcept;
|
||||||
SSLListener(SSLListener &&o) noexcept = default;
|
SSLListener(SSLListener &&o) noexcept = default;
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ namespace fr
|
|||||||
/*!
|
/*!
|
||||||
* Closes the socket
|
* Closes the socket
|
||||||
*/
|
*/
|
||||||
virtual void close_socket() override;
|
void close_socket() override;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Calls the shutdown syscall on the socket.
|
* Calls the shutdown syscall on the socket.
|
||||||
@ -57,21 +57,21 @@ namespace fr
|
|||||||
* it to immediately return (you might want to do this if
|
* it to immediately return (you might want to do this if
|
||||||
* you're exiting and need the blocking socket to return).
|
* you're exiting and need the blocking socket to return).
|
||||||
*/
|
*/
|
||||||
virtual void shutdown() override;
|
void shutdown() override;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Gets the socket descriptor.
|
* Gets the socket descriptor.
|
||||||
*
|
*
|
||||||
* @return The listen socket descriptor
|
* @return The listen socket descriptor
|
||||||
*/
|
*/
|
||||||
virtual int32_t get_socket_descriptor() const override;
|
int32_t get_socket_descriptor() const override;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Sets the socket descriptor.
|
* Sets the socket descriptor.
|
||||||
*
|
*
|
||||||
* @param descriptor The listen descriptor to use
|
* @param descriptor The listen descriptor to use
|
||||||
*/
|
*/
|
||||||
virtual void set_socket_descriptor(int32_t descriptor) override;
|
void set_socket_descriptor(int32_t descriptor) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mbedtls_net_context listen_fd;
|
mbedtls_net_context listen_fd;
|
||||||
|
|||||||
@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
#ifndef FRNETLIB_SSL_SOCKET_H
|
#ifndef FRNETLIB_SSL_SOCKET_H
|
||||||
#define FRNETLIB_SSL_SOCKET_H
|
#define FRNETLIB_SSL_SOCKET_H
|
||||||
|
|
||||||
#ifdef SSL_ENABLED
|
#ifdef SSL_ENABLED
|
||||||
|
|
||||||
#include "TcpSocket.h"
|
#include "TcpSocket.h"
|
||||||
@ -22,7 +21,7 @@ namespace fr
|
|||||||
class SSLSocket : public Socket
|
class SSLSocket : public Socket
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
|
explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
|
||||||
|
|
||||||
virtual ~SSLSocket() noexcept;
|
virtual ~SSLSocket() noexcept;
|
||||||
|
|
||||||
@ -83,7 +82,7 @@ namespace fr
|
|||||||
*
|
*
|
||||||
* @return The socket's descriptor.
|
* @return The socket's descriptor.
|
||||||
*/
|
*/
|
||||||
virtual int32_t get_socket_descriptor() const override
|
int32_t get_socket_descriptor() const override
|
||||||
{
|
{
|
||||||
return ssl_socket_descriptor->fd;
|
return ssl_socket_descriptor->fd;
|
||||||
}
|
}
|
||||||
@ -93,7 +92,7 @@ namespace fr
|
|||||||
*
|
*
|
||||||
* @param should_block True to block, false otherwise.
|
* @param should_block True to block, false otherwise.
|
||||||
*/
|
*/
|
||||||
virtual void set_blocking(bool should_block) override
|
void set_blocking(bool should_block) override
|
||||||
{
|
{
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
@ -103,7 +102,7 @@ namespace fr
|
|||||||
*
|
*
|
||||||
* @return True if it's connected. False otherwise.
|
* @return True if it's connected. False otherwise.
|
||||||
*/
|
*/
|
||||||
inline virtual bool connected() const override final
|
inline bool connected() const final
|
||||||
{
|
{
|
||||||
return ssl_socket_descriptor->fd > -1;
|
return ssl_socket_descriptor->fd > -1;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <utility>
|
||||||
#include <frnetlib/TcpListener.h>
|
#include <frnetlib/TcpListener.h>
|
||||||
#include "frnetlib/SSLListener.h"
|
#include "frnetlib/SSLListener.h"
|
||||||
#ifdef SSL_ENABLED
|
#ifdef SSL_ENABLED
|
||||||
@ -12,7 +13,7 @@
|
|||||||
namespace fr
|
namespace fr
|
||||||
{
|
{
|
||||||
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
|
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &crt_path, const std::string &pem_path, const std::string &private_key_path) noexcept
|
||||||
: ssl_context(ssl_context_)
|
: ssl_context(std::move(ssl_context_))
|
||||||
{
|
{
|
||||||
//Initialise SSL objects required
|
//Initialise SSL objects required
|
||||||
listen_fd.fd = -1;
|
listen_fd.fd = -1;
|
||||||
@ -53,7 +54,7 @@ namespace fr
|
|||||||
|
|
||||||
//Apply them
|
//Apply them
|
||||||
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
|
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
|
||||||
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, NULL);
|
mbedtls_ssl_conf_ca_chain(&conf, srvcert.next, nullptr);
|
||||||
|
|
||||||
if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
|
if((error = mbedtls_ssl_conf_own_cert(&conf, &srvcert, &pkey)) != 0)
|
||||||
{
|
{
|
||||||
@ -68,6 +69,7 @@ namespace fr
|
|||||||
mbedtls_x509_crt_free(&srvcert);
|
mbedtls_x509_crt_free(&srvcert);
|
||||||
mbedtls_pk_free(&pkey);
|
mbedtls_pk_free(&pkey);
|
||||||
mbedtls_ssl_config_free(&conf);
|
mbedtls_ssl_config_free(&conf);
|
||||||
|
mbedtls_net_free(&listen_fd);
|
||||||
}
|
}
|
||||||
|
|
||||||
Socket::Status fr::SSLListener::listen(const std::string &port)
|
Socket::Status fr::SSLListener::listen(const std::string &port)
|
||||||
@ -90,42 +92,46 @@ namespace fr
|
|||||||
Socket::Status SSLListener::accept(Socket &client_)
|
Socket::Status SSLListener::accept(Socket &client_)
|
||||||
{
|
{
|
||||||
//Cast to SSLSocket. Will throw bad cast on failure.
|
//Cast to SSLSocket. Will throw bad cast on failure.
|
||||||
SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
|
auto &client = dynamic_cast<SSLSocket&>(client_);
|
||||||
|
|
||||||
//Initialise mbedtls
|
//Initialise mbedtls
|
||||||
int error = 0;
|
int error = 0;
|
||||||
mbedtls_ssl_context *ssl = new mbedtls_ssl_context;
|
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
|
||||||
client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl));
|
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
|
||||||
|
|
||||||
mbedtls_ssl_init(ssl);
|
mbedtls_ssl_init(ssl.get());
|
||||||
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0)
|
mbedtls_net_init(client_fd.get());
|
||||||
|
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
|
||||||
|
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
|
||||||
{
|
{
|
||||||
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
||||||
|
free_contexts();
|
||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Accept a connection
|
//Accept a connection
|
||||||
mbedtls_net_context *client_fd = new mbedtls_net_context;
|
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), nullptr, 0, nullptr)) != 0)
|
||||||
client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
|
|
||||||
mbedtls_net_init(client_fd);
|
|
||||||
|
|
||||||
if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
|
|
||||||
{
|
{
|
||||||
std::cout << "Accept error: " << error << std::endl;
|
std::cout << "Accept error: " << error << std::endl;
|
||||||
|
free_contexts();
|
||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
|
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
||||||
|
|
||||||
//SSL Handshake
|
//SSL Handshake
|
||||||
while((error = mbedtls_ssl_handshake(ssl)) != 0)
|
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
||||||
{
|
{
|
||||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||||
{
|
{
|
||||||
std::cout << "Handshake error: " << error << std::endl;
|
std::cout << "Handshake error: " << error << std::endl;
|
||||||
|
free_contexts();
|
||||||
return Socket::Status::HandshakeFailed;
|
return Socket::Status::HandshakeFailed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
client.set_ssl_context(std::move(ssl));
|
||||||
|
client.set_net_context(std::move(client_fd));
|
||||||
return Socket::Success;
|
return Socket::Success;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "frnetlib/SSLSocket.h"
|
#include "frnetlib/SSLSocket.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#ifdef SSL_ENABLED
|
#ifdef SSL_ENABLED
|
||||||
|
|
||||||
@ -12,7 +13,7 @@
|
|||||||
namespace fr
|
namespace fr
|
||||||
{
|
{
|
||||||
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
|
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
|
||||||
: ssl_context(ssl_context_)
|
: ssl_context(std::move(ssl_context_))
|
||||||
{
|
{
|
||||||
//Initialise mbedtls structures
|
//Initialise mbedtls structures
|
||||||
mbedtls_ssl_config_init(&conf);
|
mbedtls_ssl_config_init(&conf);
|
||||||
@ -30,13 +31,14 @@ namespace fr
|
|||||||
|
|
||||||
void SSLSocket::close_socket()
|
void SSLSocket::close_socket()
|
||||||
{
|
{
|
||||||
if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
|
|
||||||
{
|
|
||||||
if(ssl)
|
|
||||||
mbedtls_ssl_close_notify(ssl.get());
|
|
||||||
if(ssl_socket_descriptor)
|
if(ssl_socket_descriptor)
|
||||||
mbedtls_net_free(ssl_socket_descriptor.get());
|
mbedtls_net_free(ssl_socket_descriptor.get());
|
||||||
|
if(ssl)
|
||||||
|
{
|
||||||
|
mbedtls_ssl_close_notify(ssl.get());
|
||||||
|
mbedtls_ssl_free(ssl.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
||||||
@ -84,8 +86,8 @@ namespace fr
|
|||||||
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
|
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
|
||||||
{
|
{
|
||||||
//Initialise mbedtls stuff
|
//Initialise mbedtls stuff
|
||||||
ssl = std::unique_ptr<mbedtls_ssl_context>(new mbedtls_ssl_context);
|
ssl = std::make_unique<mbedtls_ssl_context>();
|
||||||
ssl_socket_descriptor = std::unique_ptr<mbedtls_net_context>(new mbedtls_net_context);
|
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
|
||||||
mbedtls_ssl_init(ssl.get());
|
mbedtls_ssl_init(ssl.get());
|
||||||
mbedtls_net_init(ssl_socket_descriptor.get());
|
mbedtls_net_init(ssl_socket_descriptor.get());
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ namespace fr
|
|||||||
}
|
}
|
||||||
|
|
||||||
mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
mbedtls_ssl_conf_authmode(&conf, MBEDTLS_SSL_VERIFY_REQUIRED);
|
||||||
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, NULL);
|
mbedtls_ssl_conf_ca_chain(&conf, &ssl_context->cacert, nullptr);
|
||||||
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
|
mbedtls_ssl_conf_rng(&conf, mbedtls_ctr_drbg_random, &ssl_context->ctr_drbg);
|
||||||
|
|
||||||
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
|
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
|
||||||
@ -116,7 +118,7 @@ namespace fr
|
|||||||
return Socket::Status::Error;
|
return Socket::Status::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, NULL);
|
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
||||||
|
|
||||||
//Do SSL handshake
|
//Do SSL handshake
|
||||||
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user