From decb0b10f94cfc898e191844f1b73cda54a0b13f Mon Sep 17 00:00:00 2001 From: Fred Nicolson Date: Mon, 13 Aug 2018 12:35:52 +0100 Subject: [PATCH] Fixed broken get_remote_address() for sockets accepted over SSL + A few other correctness fixes. --- include/frnetlib/SSLListener.h | 6 +++--- include/frnetlib/version.h | 6 +++--- src/SSLListener.cpp | 17 ++++++++++------- src/SSLSocket.cpp | 4 +--- src/Sha1.cpp | 4 ++-- src/TcpListener.cpp | 6 ++++-- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/include/frnetlib/SSLListener.h b/include/frnetlib/SSLListener.h index 2b04702..f4d20dd 100644 --- a/include/frnetlib/SSLListener.h +++ b/include/frnetlib/SSLListener.h @@ -23,7 +23,7 @@ namespace fr { public: explicit SSLListener(std::shared_ptr ssl_context, const std::string &pem_path, const std::string &private_key_path); - virtual ~SSLListener() noexcept; + ~SSLListener() override; SSLListener(SSLListener &&) = delete; SSLListener(SSLListener &o) = delete; void operator=(const SSLListener &) = delete; @@ -35,7 +35,7 @@ namespace fr * @param port The port to bind to * @return If the operation was successful */ - virtual Socket::Status listen(const std::string &port) override; + Socket::Status listen(const std::string &port) override; /*! * Accepts a new connection. @@ -43,7 +43,7 @@ namespace fr * @param client Where to store the connection information * @return True on success. False on failure. */ - virtual Socket::Status accept(Socket &client) override; + Socket::Status accept(Socket &client) override; /*! * Closes the socket diff --git a/include/frnetlib/version.h b/include/frnetlib/version.h index 5d7c9ac..27f89f2 100644 --- a/include/frnetlib/version.h +++ b/include/frnetlib/version.h @@ -9,10 +9,10 @@ #define FRNETLIB_VERSION_MAJOR 1 #define FRNETLIB_VERSION_MINOR 0 -#define FRNETLIB_VERSION_PATCH 1 +#define FRNETLIB_VERSION_PATCH 2 #define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH) -#define FRNETLIB_VERSION_STRING "1.0.1" -#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.1" +#define FRNETLIB_VERSION_STRING "1.0.2" +#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2" #endif //FRNETLIB_VERSION_H diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index 5b0a177..bfcab87 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -85,19 +85,18 @@ namespace fr Socket::Status SSLListener::accept(Socket &client_) { //Cast to SSLSocket. Will throw bad cast on failure. - SSLSocket &client = dynamic_cast(client_); + auto &client = dynamic_cast(client_); //Initialise mbedtls int error = 0; - std::unique_ptr ssl(new mbedtls_ssl_context); - std::unique_ptr client_fd(new mbedtls_net_context); + auto ssl = std::make_unique(); + auto client_fd = std::make_unique(); mbedtls_ssl_init(ssl.get()); mbedtls_net_init(client_fd.get()); auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());}; - if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) + if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0) { - std::cout << "Failed to apply SSL setings: " << error << std::endl; free_contexts(); return Socket::Error; } @@ -111,9 +110,9 @@ namespace fr return Socket::Error; } - mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); //SSL Handshake + mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); while((error = mbedtls_ssl_handshake(ssl.get())) != 0) { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) @@ -128,12 +127,16 @@ namespace fr //Get printable address. If we failed then set it as just 'unknown' char client_printable_addr[INET6_ADDRSTRLEN]; struct sockaddr_storage socket_address{}; - socklen_t socket_length; + socklen_t socket_length = sizeof(socket_address); error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length); if(error == 0) + { error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST); + } if(error != 0) + { strcpy(client_printable_addr, "unknown"); + } client.set_ssl_context(std::move(ssl)); client.set_descriptor(client_fd.release()); diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 2181b53..259347c 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -128,7 +128,6 @@ namespace fr { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { - std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl; return Socket::Status::HandshakeFailed; } } @@ -137,9 +136,8 @@ namespace fr if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0)) { char verify_buffer[512]; - mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags ); + mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags); - std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl; return Socket::Status::VerificationFailed; } diff --git a/src/Sha1.cpp b/src/Sha1.cpp index 2778eaa..c88ca08 100644 --- a/src/Sha1.cpp +++ b/src/Sha1.cpp @@ -273,11 +273,11 @@ namespace fr uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8; /* Padding */ - buffer += 0x80; + buffer += static_cast(0x80); size_t orig_size = buffer.size(); while(buffer.size() < BLOCK_BYTES) { - buffer += (char) 0x00; + buffer += static_cast(0x00); } uint32_t block[BLOCK_INTS]; diff --git a/src/TcpListener.cpp b/src/TcpListener.cpp index 249e0aa..cd4e29d 100644 --- a/src/TcpListener.cpp +++ b/src/TcpListener.cpp @@ -86,7 +86,7 @@ namespace fr Socket::Status TcpListener::accept(Socket &client_) { //Cast to TcpSocket. Will throw bad cast on failure. - TcpSocket &client = dynamic_cast(client_); + auto &client = dynamic_cast(client_); //Prepare to wait for the client sockaddr_storage client_addr{}; @@ -100,9 +100,11 @@ namespace fr return Socket::Unknown; //Get printable address. If we failed then set it as just 'unknown' - int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST); + int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST); if(err != 0) + { strcpy(client_printable_addr, "unknown"); + } //Set client data client.set_descriptor(&client_descriptor);