fr::SSLSocket bug fixes

Whilst receiving data, the SSL socket would not return if the remote socket disconnected, leading to it blocking indefinitely.

Whilst sending enough data to require multiple writes, the socket would disconnect instead of sending more. This has also been fixed.
This commit is contained in:
Fred Nicolson 2017-07-19 13:54:42 +01:00
parent ec95d0ac36
commit d3b51f75a5
2 changed files with 23 additions and 22 deletions

View File

@ -94,38 +94,38 @@ namespace fr
//Initialise mbedtls
int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
mbedtls_ssl_init(ssl.get());
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
mbedtls_ssl_context *ssl = new mbedtls_ssl_context;
client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl));
mbedtls_ssl_init(ssl);
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0)
{
std::cout << "Failed to apply SSL setings: " << error << std::endl;
return Socket::Error;
}
//Accept a connection
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
mbedtls_net_init(client_fd.get());
mbedtls_net_context *client_fd = new mbedtls_net_context;
client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
mbedtls_net_init(client_fd);
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0)
if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
{
std::cout << "Accept error: " << error << std::endl;
return Socket::Error;
}
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL);
mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
//SSL Handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
while((error = mbedtls_ssl_handshake(ssl)) != 0)
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{
std::cout << "Handshake error: " << error << std::endl;
return Socket::Status::HandshakeFailed;
}
}
//Set socket details
client.set_net_context(std::move(client_fd));
client.set_ssl_context(std::move(ssl));
return Socket::Success;
}

View File

@ -30,7 +30,7 @@ namespace fr
void SSLSocket::close_socket()
{
if(ssl_socket_descriptor->fd > -1)
if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
{
if(ssl)
mbedtls_ssl_close_notify(ssl.get());
@ -41,10 +41,16 @@ namespace fr
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
{
int error = 0;
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
int response = 0;
size_t data_sent = 0;
while(data_sent < size)
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent);
if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE)
{
data_sent += response;
}
else if(response < 0)
{
close_socket();
return Socket::Status::Disconnected;
@ -64,16 +70,11 @@ namespace fr
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
}
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0)
{
close_socket();
return Socket::Status::Disconnected;
}
else if(read <= 0)
{
//No data. But no error occurred.
return Socket::Status::Success;
}
received += read;
return Socket::Status::Success;