Fixed broken get_remote_address() for sockets accepted over SSL

+ A few other correctness fixes.
This commit is contained in:
Fred Nicolson 2018-08-13 12:35:52 +01:00
parent db738e9503
commit decb0b10f9
6 changed files with 23 additions and 20 deletions

View File

@ -23,7 +23,7 @@ namespace fr
{ {
public: public:
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path); explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path);
virtual ~SSLListener() noexcept; ~SSLListener() override;
SSLListener(SSLListener &&) = delete; SSLListener(SSLListener &&) = delete;
SSLListener(SSLListener &o) = delete; SSLListener(SSLListener &o) = delete;
void operator=(const SSLListener &) = delete; void operator=(const SSLListener &) = delete;
@ -35,7 +35,7 @@ namespace fr
* @param port The port to bind to * @param port The port to bind to
* @return If the operation was successful * @return If the operation was successful
*/ */
virtual Socket::Status listen(const std::string &port) override; Socket::Status listen(const std::string &port) override;
/*! /*!
* Accepts a new connection. * Accepts a new connection.
@ -43,7 +43,7 @@ namespace fr
* @param client Where to store the connection information * @param client Where to store the connection information
* @return True on success. False on failure. * @return True on success. False on failure.
*/ */
virtual Socket::Status accept(Socket &client) override; Socket::Status accept(Socket &client) override;
/*! /*!
* Closes the socket * Closes the socket

View File

@ -9,10 +9,10 @@
#define FRNETLIB_VERSION_MAJOR 1 #define FRNETLIB_VERSION_MAJOR 1
#define FRNETLIB_VERSION_MINOR 0 #define FRNETLIB_VERSION_MINOR 0
#define FRNETLIB_VERSION_PATCH 1 #define FRNETLIB_VERSION_PATCH 2
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH) #define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
#define FRNETLIB_VERSION_STRING "1.0.1" #define FRNETLIB_VERSION_STRING "1.0.2"
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.1" #define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2"
#endif //FRNETLIB_VERSION_H #endif //FRNETLIB_VERSION_H

View File

@ -85,19 +85,18 @@ namespace fr
Socket::Status SSLListener::accept(Socket &client_) Socket::Status SSLListener::accept(Socket &client_)
{ {
//Cast to SSLSocket. Will throw bad cast on failure. //Cast to SSLSocket. Will throw bad cast on failure.
SSLSocket &client = dynamic_cast<SSLSocket&>(client_); auto &client = dynamic_cast<SSLSocket&>(client_);
//Initialise mbedtls //Initialise mbedtls
int error = 0; int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context); auto ssl = std::make_unique<mbedtls_ssl_context>();
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context); auto client_fd = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get()); mbedtls_ssl_init(ssl.get());
mbedtls_net_init(client_fd.get()); mbedtls_net_init(client_fd.get());
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());}; auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0) if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
{ {
std::cout << "Failed to apply SSL setings: " << error << std::endl;
free_contexts(); free_contexts();
return Socket::Error; return Socket::Error;
} }
@ -111,9 +110,9 @@ namespace fr
return Socket::Error; return Socket::Error;
} }
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake //SSL Handshake
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
while((error = mbedtls_ssl_handshake(ssl.get())) != 0) while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
@ -128,12 +127,16 @@ namespace fr
//Get printable address. If we failed then set it as just 'unknown' //Get printable address. If we failed then set it as just 'unknown'
char client_printable_addr[INET6_ADDRSTRLEN]; char client_printable_addr[INET6_ADDRSTRLEN];
struct sockaddr_storage socket_address{}; struct sockaddr_storage socket_address{};
socklen_t socket_length; socklen_t socket_length = sizeof(socket_address);
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length); error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
if(error == 0) if(error == 0)
{
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST); error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
}
if(error != 0) if(error != 0)
{
strcpy(client_printable_addr, "unknown"); strcpy(client_printable_addr, "unknown");
}
client.set_ssl_context(std::move(ssl)); client.set_ssl_context(std::move(ssl));
client.set_descriptor(client_fd.release()); client.set_descriptor(client_fd.release());

View File

@ -128,7 +128,6 @@ namespace fr
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{ {
std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl;
return Socket::Status::HandshakeFailed; return Socket::Status::HandshakeFailed;
} }
} }
@ -137,9 +136,8 @@ namespace fr
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0)) if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
{ {
char verify_buffer[512]; char verify_buffer[512];
mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags ); mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags);
std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl;
return Socket::Status::VerificationFailed; return Socket::Status::VerificationFailed;
} }

View File

@ -273,11 +273,11 @@ namespace fr
uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8; uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8;
/* Padding */ /* Padding */
buffer += 0x80; buffer += static_cast<char>(0x80);
size_t orig_size = buffer.size(); size_t orig_size = buffer.size();
while(buffer.size() < BLOCK_BYTES) while(buffer.size() < BLOCK_BYTES)
{ {
buffer += (char) 0x00; buffer += static_cast<char>(0x00);
} }
uint32_t block[BLOCK_INTS]; uint32_t block[BLOCK_INTS];

View File

@ -86,7 +86,7 @@ namespace fr
Socket::Status TcpListener::accept(Socket &client_) Socket::Status TcpListener::accept(Socket &client_)
{ {
//Cast to TcpSocket. Will throw bad cast on failure. //Cast to TcpSocket. Will throw bad cast on failure.
TcpSocket &client = dynamic_cast<TcpSocket&>(client_); auto &client = dynamic_cast<TcpSocket&>(client_);
//Prepare to wait for the client //Prepare to wait for the client
sockaddr_storage client_addr{}; sockaddr_storage client_addr{};
@ -100,9 +100,11 @@ namespace fr
return Socket::Unknown; return Socket::Unknown;
//Get printable address. If we failed then set it as just 'unknown' //Get printable address. If we failed then set it as just 'unknown'
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST); int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST);
if(err != 0) if(err != 0)
{
strcpy(client_printable_addr, "unknown"); strcpy(client_printable_addr, "unknown");
}
//Set client data //Set client data
client.set_descriptor(&client_descriptor); client.set_descriptor(&client_descriptor);