Fixed broken get_remote_address() for sockets accepted over SSL

+ A few other correctness fixes.
This commit is contained in:
Fred Nicolson 2018-08-13 12:35:52 +01:00
parent db738e9503
commit decb0b10f9
6 changed files with 23 additions and 20 deletions

View File

@ -23,7 +23,7 @@ namespace fr
{
public:
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path);
virtual ~SSLListener() noexcept;
~SSLListener() override;
SSLListener(SSLListener &&) = delete;
SSLListener(SSLListener &o) = delete;
void operator=(const SSLListener &) = delete;
@ -35,7 +35,7 @@ namespace fr
* @param port The port to bind to
* @return If the operation was successful
*/
virtual Socket::Status listen(const std::string &port) override;
Socket::Status listen(const std::string &port) override;
/*!
* Accepts a new connection.
@ -43,7 +43,7 @@ namespace fr
* @param client Where to store the connection information
* @return True on success. False on failure.
*/
virtual Socket::Status accept(Socket &client) override;
Socket::Status accept(Socket &client) override;
/*!
* Closes the socket

View File

@ -9,10 +9,10 @@
#define FRNETLIB_VERSION_MAJOR 1
#define FRNETLIB_VERSION_MINOR 0
#define FRNETLIB_VERSION_PATCH 1
#define FRNETLIB_VERSION_PATCH 2
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
#define FRNETLIB_VERSION_STRING "1.0.1"
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.1"
#define FRNETLIB_VERSION_STRING "1.0.2"
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2"
#endif //FRNETLIB_VERSION_H

View File

@ -85,19 +85,18 @@ namespace fr
Socket::Status SSLListener::accept(Socket &client_)
{
//Cast to SSLSocket. Will throw bad cast on failure.
SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
auto &client = dynamic_cast<SSLSocket&>(client_);
//Initialise mbedtls
int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
auto ssl = std::make_unique<mbedtls_ssl_context>();
auto client_fd = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(client_fd.get());
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
{
std::cout << "Failed to apply SSL setings: " << error << std::endl;
free_contexts();
return Socket::Error;
}
@ -111,9 +110,9 @@ namespace fr
return Socket::Error;
}
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
@ -128,12 +127,16 @@ namespace fr
//Get printable address. If we failed then set it as just 'unknown'
char client_printable_addr[INET6_ADDRSTRLEN];
struct sockaddr_storage socket_address{};
socklen_t socket_length;
socklen_t socket_length = sizeof(socket_address);
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
if(error == 0)
{
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
}
if(error != 0)
{
strcpy(client_printable_addr, "unknown");
}
client.set_ssl_context(std::move(ssl));
client.set_descriptor(client_fd.release());

View File

@ -128,7 +128,6 @@ namespace fr
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{
std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl;
return Socket::Status::HandshakeFailed;
}
}
@ -137,9 +136,8 @@ namespace fr
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
{
char verify_buffer[512];
mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags );
mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags);
std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl;
return Socket::Status::VerificationFailed;
}

View File

@ -273,11 +273,11 @@ namespace fr
uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8;
/* Padding */
buffer += 0x80;
buffer += static_cast<char>(0x80);
size_t orig_size = buffer.size();
while(buffer.size() < BLOCK_BYTES)
{
buffer += (char) 0x00;
buffer += static_cast<char>(0x00);
}
uint32_t block[BLOCK_INTS];

View File

@ -86,7 +86,7 @@ namespace fr
Socket::Status TcpListener::accept(Socket &client_)
{
//Cast to TcpSocket. Will throw bad cast on failure.
TcpSocket &client = dynamic_cast<TcpSocket&>(client_);
auto &client = dynamic_cast<TcpSocket&>(client_);
//Prepare to wait for the client
sockaddr_storage client_addr{};
@ -100,9 +100,11 @@ namespace fr
return Socket::Unknown;
//Get printable address. If we failed then set it as just 'unknown'
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST);
if(err != 0)
{
strcpy(client_printable_addr, "unknown");
}
//Set client data
client.set_descriptor(&client_descriptor);