diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index cd71174..6503d62 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -94,38 +94,38 @@ namespace fr //Initialise mbedtls int error = 0; - std::unique_ptr ssl(new mbedtls_ssl_context); - mbedtls_ssl_init(ssl.get()); - if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) + mbedtls_ssl_context *ssl = new mbedtls_ssl_context; + client.set_ssl_context(std::unique_ptr(ssl)); + + mbedtls_ssl_init(ssl); + if((error = mbedtls_ssl_setup(ssl, &conf ) ) != 0) { std::cout << "Failed to apply SSL setings: " << error << std::endl; return Socket::Error; } //Accept a connection - std::unique_ptr client_fd(new mbedtls_net_context); - mbedtls_net_init(client_fd.get()); + mbedtls_net_context *client_fd = new mbedtls_net_context; + client.set_net_context(std::unique_ptr(client_fd)); + mbedtls_net_init(client_fd); - if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), NULL, 0, NULL)) != 0) + if((error = mbedtls_net_accept(&listen_fd, client_fd, NULL, 0, NULL)) != 0) { std::cout << "Accept error: " << error << std::endl; return Socket::Error; } - mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, NULL); + mbedtls_ssl_set_bio(ssl, client_fd, mbedtls_net_send, mbedtls_net_recv, NULL); //SSL Handshake - while((error = mbedtls_ssl_handshake(ssl.get())) != 0) + while((error = mbedtls_ssl_handshake(ssl)) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { + std::cout << "Handshake error: " << error << std::endl; return Socket::Status::HandshakeFailed; } } - - //Set socket details - client.set_net_context(std::move(client_fd)); - client.set_ssl_context(std::move(ssl)); return Socket::Success; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index cae2aed..23d6d66 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -30,7 +30,7 @@ namespace fr void SSLSocket::close_socket() { - if(ssl_socket_descriptor->fd > -1) + if(ssl_socket_descriptor && ssl_socket_descriptor->fd > -1) { if(ssl) mbedtls_ssl_close_notify(ssl.get()); @@ -41,10 +41,16 @@ namespace fr Socket::Status SSLSocket::send_raw(const char *data, size_t size) { - int error = 0; - while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0) + int response = 0; + size_t data_sent = 0; + while(data_sent < size) { - if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) + response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent); + if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE) + { + data_sent += response; + } + else if(response < 0) { close_socket(); return Socket::Status::Disconnected; @@ -64,16 +70,11 @@ namespace fr read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); } - if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) + if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0) { close_socket(); return Socket::Status::Disconnected; } - else if(read <= 0) - { - //No data. But no error occurred. - return Socket::Status::Success; - } received += read; return Socket::Status::Success;