Fixed broken get_remote_address() for sockets accepted over SSL
+ A few other correctness fixes.
This commit is contained in:
parent
db738e9503
commit
decb0b10f9
@ -23,7 +23,7 @@ namespace fr
|
|||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path);
|
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path);
|
||||||
virtual ~SSLListener() noexcept;
|
~SSLListener() override;
|
||||||
SSLListener(SSLListener &&) = delete;
|
SSLListener(SSLListener &&) = delete;
|
||||||
SSLListener(SSLListener &o) = delete;
|
SSLListener(SSLListener &o) = delete;
|
||||||
void operator=(const SSLListener &) = delete;
|
void operator=(const SSLListener &) = delete;
|
||||||
@ -35,7 +35,7 @@ namespace fr
|
|||||||
* @param port The port to bind to
|
* @param port The port to bind to
|
||||||
* @return If the operation was successful
|
* @return If the operation was successful
|
||||||
*/
|
*/
|
||||||
virtual Socket::Status listen(const std::string &port) override;
|
Socket::Status listen(const std::string &port) override;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Accepts a new connection.
|
* Accepts a new connection.
|
||||||
@ -43,7 +43,7 @@ namespace fr
|
|||||||
* @param client Where to store the connection information
|
* @param client Where to store the connection information
|
||||||
* @return True on success. False on failure.
|
* @return True on success. False on failure.
|
||||||
*/
|
*/
|
||||||
virtual Socket::Status accept(Socket &client) override;
|
Socket::Status accept(Socket &client) override;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Closes the socket
|
* Closes the socket
|
||||||
|
|||||||
@ -9,10 +9,10 @@
|
|||||||
|
|
||||||
#define FRNETLIB_VERSION_MAJOR 1
|
#define FRNETLIB_VERSION_MAJOR 1
|
||||||
#define FRNETLIB_VERSION_MINOR 0
|
#define FRNETLIB_VERSION_MINOR 0
|
||||||
#define FRNETLIB_VERSION_PATCH 1
|
#define FRNETLIB_VERSION_PATCH 2
|
||||||
|
|
||||||
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
|
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
|
||||||
#define FRNETLIB_VERSION_STRING "1.0.1"
|
#define FRNETLIB_VERSION_STRING "1.0.2"
|
||||||
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.1"
|
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2"
|
||||||
|
|
||||||
#endif //FRNETLIB_VERSION_H
|
#endif //FRNETLIB_VERSION_H
|
||||||
|
|||||||
@ -85,19 +85,18 @@ namespace fr
|
|||||||
Socket::Status SSLListener::accept(Socket &client_)
|
Socket::Status SSLListener::accept(Socket &client_)
|
||||||
{
|
{
|
||||||
//Cast to SSLSocket. Will throw bad cast on failure.
|
//Cast to SSLSocket. Will throw bad cast on failure.
|
||||||
SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
|
auto &client = dynamic_cast<SSLSocket&>(client_);
|
||||||
|
|
||||||
//Initialise mbedtls
|
//Initialise mbedtls
|
||||||
int error = 0;
|
int error = 0;
|
||||||
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
|
auto ssl = std::make_unique<mbedtls_ssl_context>();
|
||||||
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
|
auto client_fd = std::make_unique<mbedtls_net_context>();
|
||||||
|
|
||||||
mbedtls_ssl_init(ssl.get());
|
mbedtls_ssl_init(ssl.get());
|
||||||
mbedtls_net_init(client_fd.get());
|
mbedtls_net_init(client_fd.get());
|
||||||
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
|
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
|
||||||
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
|
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
|
||||||
{
|
{
|
||||||
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
|
||||||
free_contexts();
|
free_contexts();
|
||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
@ -111,9 +110,9 @@ namespace fr
|
|||||||
return Socket::Error;
|
return Socket::Error;
|
||||||
}
|
}
|
||||||
|
|
||||||
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
|
||||||
|
|
||||||
//SSL Handshake
|
//SSL Handshake
|
||||||
|
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
||||||
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
||||||
{
|
{
|
||||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||||
@ -128,12 +127,16 @@ namespace fr
|
|||||||
//Get printable address. If we failed then set it as just 'unknown'
|
//Get printable address. If we failed then set it as just 'unknown'
|
||||||
char client_printable_addr[INET6_ADDRSTRLEN];
|
char client_printable_addr[INET6_ADDRSTRLEN];
|
||||||
struct sockaddr_storage socket_address{};
|
struct sockaddr_storage socket_address{};
|
||||||
socklen_t socket_length;
|
socklen_t socket_length = sizeof(socket_address);
|
||||||
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
|
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
|
||||||
if(error == 0)
|
if(error == 0)
|
||||||
|
{
|
||||||
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
|
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
|
||||||
|
}
|
||||||
if(error != 0)
|
if(error != 0)
|
||||||
|
{
|
||||||
strcpy(client_printable_addr, "unknown");
|
strcpy(client_printable_addr, "unknown");
|
||||||
|
}
|
||||||
|
|
||||||
client.set_ssl_context(std::move(ssl));
|
client.set_ssl_context(std::move(ssl));
|
||||||
client.set_descriptor(client_fd.release());
|
client.set_descriptor(client_fd.release());
|
||||||
|
|||||||
@ -128,7 +128,6 @@ namespace fr
|
|||||||
{
|
{
|
||||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||||
{
|
{
|
||||||
std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl;
|
|
||||||
return Socket::Status::HandshakeFailed;
|
return Socket::Status::HandshakeFailed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -137,9 +136,8 @@ namespace fr
|
|||||||
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
|
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
|
||||||
{
|
{
|
||||||
char verify_buffer[512];
|
char verify_buffer[512];
|
||||||
mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags );
|
mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags);
|
||||||
|
|
||||||
std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl;
|
|
||||||
return Socket::Status::VerificationFailed;
|
return Socket::Status::VerificationFailed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -273,11 +273,11 @@ namespace fr
|
|||||||
uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8;
|
uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8;
|
||||||
|
|
||||||
/* Padding */
|
/* Padding */
|
||||||
buffer += 0x80;
|
buffer += static_cast<char>(0x80);
|
||||||
size_t orig_size = buffer.size();
|
size_t orig_size = buffer.size();
|
||||||
while(buffer.size() < BLOCK_BYTES)
|
while(buffer.size() < BLOCK_BYTES)
|
||||||
{
|
{
|
||||||
buffer += (char) 0x00;
|
buffer += static_cast<char>(0x00);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t block[BLOCK_INTS];
|
uint32_t block[BLOCK_INTS];
|
||||||
|
|||||||
@ -86,7 +86,7 @@ namespace fr
|
|||||||
Socket::Status TcpListener::accept(Socket &client_)
|
Socket::Status TcpListener::accept(Socket &client_)
|
||||||
{
|
{
|
||||||
//Cast to TcpSocket. Will throw bad cast on failure.
|
//Cast to TcpSocket. Will throw bad cast on failure.
|
||||||
TcpSocket &client = dynamic_cast<TcpSocket&>(client_);
|
auto &client = dynamic_cast<TcpSocket&>(client_);
|
||||||
|
|
||||||
//Prepare to wait for the client
|
//Prepare to wait for the client
|
||||||
sockaddr_storage client_addr{};
|
sockaddr_storage client_addr{};
|
||||||
@ -100,9 +100,11 @@ namespace fr
|
|||||||
return Socket::Unknown;
|
return Socket::Unknown;
|
||||||
|
|
||||||
//Get printable address. If we failed then set it as just 'unknown'
|
//Get printable address. If we failed then set it as just 'unknown'
|
||||||
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
|
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST);
|
||||||
if(err != 0)
|
if(err != 0)
|
||||||
|
{
|
||||||
strcpy(client_printable_addr, "unknown");
|
strcpy(client_printable_addr, "unknown");
|
||||||
|
}
|
||||||
|
|
||||||
//Set client data
|
//Set client data
|
||||||
client.set_descriptor(&client_descriptor);
|
client.set_descriptor(&client_descriptor);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user