fr::SSLSocket bug fixes
Whilst receiving data, the SSL socket would not return if the remote socket disconnected, leading to it blocking indefinitely. Whilst sending enough data to require multiple writes, the socket would disconnect instead of sending more. This has also been fixed.
This commit is contained in:
parent
ec95d0ac36
commit
d3b51f75a5
@ -94,38 +94,38 @@ namespace fr
|
|||||||
|
|
||||||
//Initialise mbedtls
|
//Initialise mbedtls
|
||||||
int error = 0;
|
int error = 0;
|
||||||
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
|
mbedtls_ssl_context *ssl = new mbedtls_ssl_context;
|
||||||
mbedtls_ssl_init(ssl.get());
|
client.set_ssl_context(std::unique_ptr<mbedtls_ssl_context>(ssl));
|
||||||
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
|
|
||||||
|
mbedtls_ssl_init(ssl);
|
||||||
|
if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0)
|
||||||
{
|
{
|
||||||
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Accept a connection
|
//Accept a connection
|
||||||
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
|
mbedtls_net_context *client_fd = new mbedtls_net_context;
|
||||||
mbedtls_net_init(client_fd.get());
|
client.set_net_context(std::unique_ptr<mbedtls_net_context>(client_fd));
|
||||||
|
mbedtls_net_init(client_fd);
|
||||||
|
|
||||||
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0)
|
if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0)
|
||||||
{
|
{
|
||||||
std::cout << "Accept error: " << error << std::endl;
|
std::cout << "Accept error: " << error << std::endl;
|
||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL);
|
mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL);
|
||||||
|
|
||||||
//SSL Handshake
|
//SSL Handshake
|
||||||
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
while((error = mbedtls_ssl_handshake(ssl)) != 0)
|
||||||
{
|
{
|
||||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||||
{
|
{
|
||||||
|
std::cout << "Handshake error: " << error << std::endl;
|
||||||
return Socket::Status::HandshakeFailed;
|
return Socket::Status::HandshakeFailed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Set socket details
|
|
||||||
client.set_net_context(std::move(client_fd));
|
|
||||||
client.set_ssl_context(std::move(ssl));
|
|
||||||
return Socket::Success;
|
return Socket::Success;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ namespace fr
|
|||||||
|
|
||||||
void SSLSocket::close_socket()
|
void SSLSocket::close_socket()
|
||||||
{
|
{
|
||||||
if(ssl_socket_descriptor->fd > -1)
|
if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1)
|
||||||
{
|
{
|
||||||
if(ssl)
|
if(ssl)
|
||||||
mbedtls_ssl_close_notify(ssl.get());
|
mbedtls_ssl_close_notify(ssl.get());
|
||||||
@ -41,10 +41,16 @@ namespace fr
|
|||||||
|
|
||||||
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
||||||
{
|
{
|
||||||
int error = 0;
|
int response = 0;
|
||||||
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
|
size_t data_sent = 0;
|
||||||
|
while(data_sent < size)
|
||||||
{
|
{
|
||||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent);
|
||||||
|
if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||||
|
{
|
||||||
|
data_sent += response;
|
||||||
|
}
|
||||||
|
else if(response < 0)
|
||||||
{
|
{
|
||||||
close_socket();
|
close_socket();
|
||||||
return Socket::Status::Disconnected;
|
return Socket::Status::Disconnected;
|
||||||
@ -64,16 +70,11 @@ namespace fr
|
|||||||
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
|
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
|
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0)
|
||||||
{
|
{
|
||||||
close_socket();
|
close_socket();
|
||||||
return Socket::Status::Disconnected;
|
return Socket::Status::Disconnected;
|
||||||
}
|
}
|
||||||
else if(read <= 0)
|
|
||||||
{
|
|
||||||
//No data. But no error occurred.
|
|
||||||
return Socket::Status::Success;
|
|
||||||
}
|
|
||||||
|
|
||||||
received += read;
|
received += read;
|
||||||
return Socket::Status::Success;
|
return Socket::Status::Success;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user