fr::SSLSocket bug fixes

Whilst receiving data, the SSL socket would not return if the remote socket disconnected, leading to it blocking indefinitely.

Whilst sending enough data to require multiple writes, the socket would disconnect instead of sending more. This has also been fixed.
This commit is contained in:
Fred Nicolson 2017-07-19 13:54:42 +01:00
parent ec95d0ac36
commit d3b51f75a5
2 changed files with 23 additions and 22 deletions

View File

@ -94,38 +94,38 @@ namespace fr
//Initialise mbedtls //Initialise mbedtls
int error = 0; int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context); mbedtls_ssl_context *ssl = new mbedtls_ssl_context;
mbedtls_ssl_init(ssl.get()); client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl));
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
mbedtls_ssl_init(ssl);
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0)
{ {
std::cout << "Failed to apply SSL setings: " << error << std::endl; std::cout << "Failed to apply SSL setings: " << error << std::endl;
return Socket::Error; return Socket::Error;
} }
//Accept a connection //Accept a connection
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context); mbedtls_net_context *client_fd = new mbedtls_net_context;
mbedtls_net_init(client_fd.get()); client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
mbedtls_net_init(client_fd);
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0) if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
{ {
std::cout << "Accept error: " << error << std::endl; std::cout << "Accept error: " << error << std::endl;
return Socket::Error; return Socket::Error;
} }
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL); mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
//SSL Handshake //SSL Handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0) while((error = mbedtls_ssl_handshake(ssl)) != 0)
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{ {
std::cout << "Handshake error: " << error << std::endl;
return Socket::Status::HandshakeFailed; return Socket::Status::HandshakeFailed;
} }
} }
//Set socket details
client.set_net_context(std::move(client_fd));
client.set_ssl_context(std::move(ssl));
return Socket::Success; return Socket::Success;
} }

View File

@ -30,7 +30,7 @@ namespace fr
void SSLSocket::close_socket() void SSLSocket::close_socket()
{ {
if(ssl_socket_descriptor->fd > -1) if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
{ {
if(ssl) if(ssl)
mbedtls_ssl_close_notify(ssl.get()); mbedtls_ssl_close_notify(ssl.get());
@ -41,10 +41,16 @@ namespace fr
Socket::Status SSLSocket::send_raw(const char *data, size_t size) Socket::Status SSLSocket::send_raw(const char *data, size_t size)
{ {
int error = 0; int response = 0;
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0) size_t data_sent = 0;
while(data_sent < size)
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent);
if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE)
{
data_sent += response;
}
else if(response < 0)
{ {
close_socket(); close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
@ -64,16 +70,11 @@ namespace fr
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
} }
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0)
{ {
close_socket(); close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
} }
else if(read <= 0)
{
//No data. But no error occurred.
return Socket::Status::Success;
}
received += read; received += read;
return Socket::Status::Success; return Socket::Status::Success;