From d3b51f75a541e8739fb86f14349d92a62837f7aa Mon Sep 17 00:00:00 2001 From: Fred Nicolson Date: Wed, 19 Jul 2017 13:54:42 +0100 Subject: [PATCH] fr::SSLSocket bug fixes Whilst receiving data, the SSL socket would not return if the remote socket disconnected, leading to it blocking indefinitely. Whilst sending enough data to require multiple writes, the socket would disconnect instead of sending more. This has also been fixed. --- src/SSLListener.cpp | 24 ++++++++++++------------ src/SSLSocket.cpp | 21 +++++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index cd71174..6503d62 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -94,38 +94,38 @@ namespace fr //Initialise mbedtls int error = 0; - std::unique_ptr ssl(new mbedtls_ssl_context); - mbedtls_ssl_init(ssl.get()); - if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) + mbedtls_ssl_context *ssl = new mbedtls_ssl_context; + client.set_ssl_context(std::unique_ptr(ssl)); + + mbedtls_ssl_init(ssl); + if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0) { std::cout << "Failed to apply SSL setings: " << error << std::endl; return Socket::Error; } //Accept a connection - std::unique_ptr client_fd(new mbedtls_net_context); - mbedtls_net_init(client_fd.get()); + mbedtls_net_context *client_fd = new mbedtls_net_context; + client.set_net_context(std::unique_ptr(client_fd)); + mbedtls_net_init(client_fd); - if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0) + if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0) { std::cout << "Accept error: " << error << std::endl; return Socket::Error; } - mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); //SSL Handshake - while((error = mbedtls_ssl_handshake(ssl.get())) != 0) + while((error = mbedtls_ssl_handshake(ssl)) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { + std::cout << "Handshake error: " << error << std::endl; return Socket::Status::HandshakeFailed; } } - - //Set socket details - client.set_net_context(std::move(client_fd)); - client.set_ssl_context(std::move(ssl)); return Socket::Success; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index cae2aed..23d6d66 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -30,7 +30,7 @@ namespace fr void SSLSocket::close_socket() { - if(ssl_socket_descriptor->fd > -1) + if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1) { if(ssl) mbedtls_ssl_close_notify(ssl.get()); @@ -41,10 +41,16 @@ namespace fr Socket::Status SSLSocket::send_raw(const char *data, size_t size) { - int error = 0; - while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0) + int response = 0; + size_t data_sent = 0; + while(data_sent < size) { - if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) + response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent); + if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE) + { + data_sent += response; + } + else if(response < 0) { close_socket(); return Socket::Status::Disconnected; @@ -64,16 +70,11 @@ namespace fr read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); } - if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) + if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0) { close_socket(); return Socket::Status::Disconnected; } - else if(read <= 0) - { - //No data. But no error occurred. - return Socket::Status::Success; - } received += read; return Socket::Status::Success;