Fixed broken get_remote_address() for sockets accepted over SSL
+ A few other correctness fixes.
This commit is contained in:
parent
db738e9503
commit
decb0b10f9
@ -23,7 +23,7 @@ namespace fr
|
||||
{
|
||||
public:
|
||||
explicit SSLListener(std::shared_ptr<SSLContext> ssl_context, const std::string &pem_path, const std::string &private_key_path);
|
||||
virtual ~SSLListener() noexcept;
|
||||
~SSLListener() override;
|
||||
SSLListener(SSLListener &&) = delete;
|
||||
SSLListener(SSLListener &o) = delete;
|
||||
void operator=(const SSLListener &) = delete;
|
||||
@ -35,7 +35,7 @@ namespace fr
|
||||
* @param port The port to bind to
|
||||
* @return If the operation was successful
|
||||
*/
|
||||
virtual Socket::Status listen(const std::string &port) override;
|
||||
Socket::Status listen(const std::string &port) override;
|
||||
|
||||
/*!
|
||||
* Accepts a new connection.
|
||||
@ -43,7 +43,7 @@ namespace fr
|
||||
* @param client Where to store the connection information
|
||||
* @return True on success. False on failure.
|
||||
*/
|
||||
virtual Socket::Status accept(Socket &client) override;
|
||||
Socket::Status accept(Socket &client) override;
|
||||
|
||||
/*!
|
||||
* Closes the socket
|
||||
|
||||
@ -9,10 +9,10 @@
|
||||
|
||||
#define FRNETLIB_VERSION_MAJOR 1
|
||||
#define FRNETLIB_VERSION_MINOR 0
|
||||
#define FRNETLIB_VERSION_PATCH 1
|
||||
#define FRNETLIB_VERSION_PATCH 2
|
||||
|
||||
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
|
||||
#define FRNETLIB_VERSION_STRING "1.0.1"
|
||||
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.1"
|
||||
#define FRNETLIB_VERSION_STRING "1.0.2"
|
||||
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2"
|
||||
|
||||
#endif //FRNETLIB_VERSION_H
|
||||
|
||||
@ -85,19 +85,18 @@ namespace fr
|
||||
Socket::Status SSLListener::accept(Socket &client_)
|
||||
{
|
||||
//Cast to SSLSocket. Will throw bad cast on failure.
|
||||
SSLSocket &client = dynamic_cast<SSLSocket&>(client_);
|
||||
auto &client = dynamic_cast<SSLSocket&>(client_);
|
||||
|
||||
//Initialise mbedtls
|
||||
int error = 0;
|
||||
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
|
||||
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
|
||||
auto ssl = std::make_unique<mbedtls_ssl_context>();
|
||||
auto client_fd = std::make_unique<mbedtls_net_context>();
|
||||
|
||||
mbedtls_ssl_init(ssl.get());
|
||||
mbedtls_net_init(client_fd.get());
|
||||
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
|
||||
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
|
||||
if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0)
|
||||
{
|
||||
std::cout << "Failed to apply SSL setings: " << error << std::endl;
|
||||
free_contexts();
|
||||
return Socket::Error;
|
||||
}
|
||||
@ -111,9 +110,9 @@ namespace fr
|
||||
return Socket::Error;
|
||||
}
|
||||
|
||||
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
||||
|
||||
//SSL Handshake
|
||||
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
|
||||
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
|
||||
{
|
||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||
@ -128,12 +127,16 @@ namespace fr
|
||||
//Get printable address. If we failed then set it as just 'unknown'
|
||||
char client_printable_addr[INET6_ADDRSTRLEN];
|
||||
struct sockaddr_storage socket_address{};
|
||||
socklen_t socket_length;
|
||||
socklen_t socket_length = sizeof(socket_address);
|
||||
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
|
||||
if(error == 0)
|
||||
{
|
||||
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
|
||||
}
|
||||
if(error != 0)
|
||||
{
|
||||
strcpy(client_printable_addr, "unknown");
|
||||
}
|
||||
|
||||
client.set_ssl_context(std::move(ssl));
|
||||
client.set_descriptor(client_fd.release());
|
||||
|
||||
@ -128,7 +128,6 @@ namespace fr
|
||||
{
|
||||
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
|
||||
{
|
||||
std::cout << "Failed to connect to server. Handshake returned: " << error << std::endl;
|
||||
return Socket::Status::HandshakeFailed;
|
||||
}
|
||||
}
|
||||
@ -137,9 +136,8 @@ namespace fr
|
||||
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
|
||||
{
|
||||
char verify_buffer[512];
|
||||
mbedtls_x509_crt_verify_info( verify_buffer, sizeof( verify_buffer ), " ! ", flags );
|
||||
mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags);
|
||||
|
||||
std::cout << "Failed to connect to server. Server certificate validation failed: " << verify_buffer << std::endl;
|
||||
return Socket::Status::VerificationFailed;
|
||||
}
|
||||
|
||||
|
||||
@ -273,11 +273,11 @@ namespace fr
|
||||
uint64_t total_bits = (transforms * BLOCK_BYTES + buffer.size()) * 8;
|
||||
|
||||
/* Padding */
|
||||
buffer += 0x80;
|
||||
buffer += static_cast<char>(0x80);
|
||||
size_t orig_size = buffer.size();
|
||||
while(buffer.size() < BLOCK_BYTES)
|
||||
{
|
||||
buffer += (char) 0x00;
|
||||
buffer += static_cast<char>(0x00);
|
||||
}
|
||||
|
||||
uint32_t block[BLOCK_INTS];
|
||||
|
||||
@ -86,7 +86,7 @@ namespace fr
|
||||
Socket::Status TcpListener::accept(Socket &client_)
|
||||
{
|
||||
//Cast to TcpSocket. Will throw bad cast on failure.
|
||||
TcpSocket &client = dynamic_cast<TcpSocket&>(client_);
|
||||
auto &client = dynamic_cast<TcpSocket&>(client_);
|
||||
|
||||
//Prepare to wait for the client
|
||||
sockaddr_storage client_addr{};
|
||||
@ -100,9 +100,11 @@ namespace fr
|
||||
return Socket::Unknown;
|
||||
|
||||
//Get printable address. If we failed then set it as just 'unknown'
|
||||
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
|
||||
int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST);
|
||||
if(err != 0)
|
||||
{
|
||||
strcpy(client_printable_addr, "unknown");
|
||||
}
|
||||
|
||||
//Set client data
|
||||
client.set_descriptor(&client_descriptor);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user