Bug fixes. fr::Socket::send_raw now behaves like receive_raw.

Added missing virtual destructors in abstract classes.
This commit is contained in:
Fred Nicolson 2019-03-21 16:24:22 +00:00
parent 7343529302
commit 6814219cbb
No known key found for this signature in database
GPG Key ID: 78C1DD87B47797D2
16 changed files with 308 additions and 169 deletions

View File

@ -7,6 +7,7 @@
#include <frnetlib/HttpRequest.h> #include <frnetlib/HttpRequest.h>
#include <frnetlib/URL.h> #include <frnetlib/URL.h>
#include <frnetlib/HttpResponse.h> #include <frnetlib/HttpResponse.h>
#include <frnetlib/TcpListener.h>
int main() int main()
{ {
@ -23,20 +24,23 @@ int main()
return EXIT_FAILURE; return EXIT_FAILURE;
} }
//Try to connect to the parsed address fr::Socket::Status err;
fr::TcpSocket socket; fr::TcpSocket socket;
if(socket.connect(parsed_url.get_host(), parsed_url.get_port(), {}) != fr::Socket::Success) fr::TcpListener listener;
//Try to connect to the parsed address
if((err = socket.connect(parsed_url.get_host(), parsed_url.get_port(), {})) != fr::Socket::Success)
{ {
std::cerr << "Failed to connect to the specified URL" << std::endl; std::cerr << "Failed to connect to the specified URL: " << fr::Socket::status_to_string(err) << std::endl;
return EXIT_FAILURE; return EXIT_FAILURE;
} }
//Construct a request, requesting the user provided URI //Construct a request, requesting the user provided URI
fr::HttpRequest request; fr::HttpRequest request;
request.set_uri(parsed_url.get_uri()); request.set_uri(parsed_url.get_uri());
if(socket.send(request) != fr::Socket::Success) if((err = socket.send(request)) != fr::Socket::Success)
{ {
std::cerr << "Failed to send HTTP request" << std::endl; std::cerr << "Failed to send HTTP request: " + fr::Socket::status_to_string(err) << std::endl;
return EXIT_FAILURE; return EXIT_FAILURE;
} }

View File

@ -707,7 +707,13 @@ namespace fr
{ {
uint32_t length = htonl((uint32_t)buffer.size() - PACKET_HEADER_LENGTH); uint32_t length = htonl((uint32_t)buffer.size() - PACKET_HEADER_LENGTH);
memcpy(&buffer[0], &length, sizeof(uint32_t)); memcpy(&buffer[0], &length, sizeof(uint32_t));
return socket->send_raw(buffer.c_str(), buffer.size()); fr::Socket::Status state;
size_t sent = 0;
do
{
state = socket->send_raw(&buffer[0], buffer.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
} }
/*! /*!

View File

@ -20,7 +20,7 @@ namespace fr
{ {
public: public:
explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept; explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
virtual ~SSLSocket() noexcept; ~SSLSocket() override;
SSLSocket(SSLSocket &&) = delete; SSLSocket(SSLSocket &&) = delete;
SSLSocket(const SSLSocket &) = delete; SSLSocket(const SSLSocket &) = delete;
void operator=(SSLSocket &&)=delete; void operator=(SSLSocket &&)=delete;
@ -30,11 +30,15 @@ namespace fr
* Effectively just fr::TcpSocket::send_raw() with encryption * Effectively just fr::TcpSocket::send_raw() with encryption
* added in. * added in.
* *
* @note If this returns WouldBlock, you must call this function again with the *same* arguments.
* @param data The data to send. * @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow. * @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation. * @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out
* 'SSLError' An SSL error has occurred.
* 'Success' All the bytes you wanted have been read
*/ */
Socket::Status send_raw(const char *data, size_t size) override; Socket::Status send_raw(const char *data, size_t size, size_t &sent) override;
/*! /*!
@ -47,6 +51,7 @@ namespace fr
* @return The status of the operation: * @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out
* 'Disconnected' if the socket has disconnected. * 'Disconnected' if the socket has disconnected.
* 'SSLError' An SSL error has occurred.
* 'Success' All the bytes you wanted have been read * 'Success' All the bytes you wanted have been read
*/ */
Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override; Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override;
@ -109,10 +114,31 @@ namespace fr
* *
* @note This must be set *WHILST* connected * @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise. * @param should_block True to block, false otherwise.
* @return A status code indicating success:
* 'SSLError' on failure.
* 'Success' on success.
*/ */
inline void set_blocking(bool should_block) override inline fr::Socket::Status set_blocking(bool should_block) override
{ {
mbedtls_net_set_block(ssl_socket_descriptor.get()); int ret = mbedtls_net_set_block(ssl_socket_descriptor.get());
if(ret != 0)
{
errno = ret;
return fr::Socket::SSLError;
}
is_blocking = should_block;
return fr::Socket::Success;
}
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
inline bool get_blocking() const override
{
return is_blocking;
} }
/*! /*!
@ -134,12 +160,12 @@ namespace fr
void close_socket() override; void close_socket() override;
std::shared_ptr<SSLContext> ssl_context; std::shared_ptr<SSLContext> ssl_context;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor; std::unique_ptr<mbedtls_net_context, decltype(&mbedtls_net_free)> ssl_socket_descriptor;
std::unique_ptr<mbedtls_ssl_context> ssl; std::unique_ptr<mbedtls_ssl_context, decltype(&mbedtls_ssl_free)> ssl;
mbedtls_ssl_config conf; mbedtls_ssl_config conf;
uint32_t flags;
bool should_verify; bool should_verify;
uint32_t receive_timeout; uint32_t receive_timeout;
bool is_blocking;
}; };
} }

View File

@ -11,6 +11,8 @@ namespace fr
class Sendable class Sendable
{ {
public: public:
virtual ~Sendable()=default;
/*! /*!
* Overridable send, to allow * Overridable send, to allow
* custom types to be directly sent through * custom types to be directly sent through

View File

@ -84,8 +84,18 @@ namespace fr
* *
* @note This must be set *WHILST* connected * @note This must be set *WHILST* connected
* @param should_block True for blocking (default argument), false otherwise. * @param should_block True for blocking (default argument), false otherwise.
* @return Status of the operation:
* 'Success' on success.
* Something else (depending on underlying socket type) on failure.
*/ */
virtual void set_blocking(bool should_block) = 0; virtual fr::Socket::Status set_blocking(bool should_block) = 0;
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
virtual bool get_blocking() const=0;
/*! /*!
* Attempts to send raw data down the socket, without * Attempts to send raw data down the socket, without
@ -94,9 +104,10 @@ namespace fr
* *
* @param data The data to send. * @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow. * @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation. * @param sent The number of bytes that could be sent.
* @return The status of the operation. Dependent on the underlying socket type.
*/ */
virtual Status send_raw(const char *data, size_t size) = 0; virtual Status send_raw(const char *data, size_t size, size_t &sent) = 0;
/*! /*!
* Receives raw data from the socket, without any of * Receives raw data from the socket, without any of
@ -268,7 +279,6 @@ namespace fr
virtual void reconfigure_socket()=0; virtual void reconfigure_socket()=0;
std::string remote_address; std::string remote_address;
bool is_blocking;
int ai_family; int ai_family;
uint32_t max_receive_size; uint32_t max_receive_size;
uint32_t socket_read_timeout; uint32_t socket_read_timeout;

View File

@ -12,19 +12,21 @@ class SocketDescriptor
{ {
public: public:
virtual ~SocketDescriptor()=default;
/*! /*!
* Checks to see if the socket is connected or not * Checks to see if the socket is connected or not
* *
* @return True if connected, false otherwise * @return True if connected, false otherwise
*/ */
virtual bool connected() const noexcept = 0; virtual bool connected() const = 0;
/*! /*!
* Gets the underlying socket descriptor. * Gets the underlying socket descriptor.
* *
* @return The socket descriptor. * @return The socket descriptor.
*/ */
virtual int32_t get_socket_descriptor() const noexcept = 0; virtual int32_t get_socket_descriptor() const = 0;
}; };
} }

View File

@ -27,7 +27,7 @@ public:
* @param port The port to bind to * @param port The port to bind to
* @return If the operation was successful * @return If the operation was successful
*/ */
virtual Socket::Status listen(const std::string &port) override; Socket::Status listen(const std::string &port) override;
/*! /*!
* Accepts a new connection. * Accepts a new connection.
@ -35,7 +35,7 @@ public:
* @param client Where to store the connection information * @param client Where to store the connection information
* @return True on success. False on failure. * @return True on success. False on failure.
*/ */
virtual Socket::Status accept(Socket &client) override; Socket::Status accept(Socket &client) override;
/*! /*!
* Calls the shutdown syscall on the socket. * Calls the shutdown syscall on the socket.

View File

@ -37,10 +37,14 @@ public:
* different protocols. * different protocols.
* *
* @param data The data to send. * @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow. * @param size The number of bytes, from data to send. Be careful not to overflow. You must zero this, prior to calling send_raw the first time.
* @return The status of the operation. * @param sent Will be filled with the number of bytes actually sent, might be less than you requested if in non-blocking mode/error.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'SendError' if a send error has disconnected.
* 'Success' All the bytes you wanted have been read
*/ */
Status send_raw(const char *data, size_t size) override; Status send_raw(const char *data, size_t size, size_t &sent) override;
/*! /*!
@ -52,10 +56,11 @@ public:
* *
* @param data Where to store the received data. * @param data Where to store the received data.
* @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested if in non-blocking mode/error.
* @return The status of the operation: * @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'Disconnected' if the socket has disconnected. * 'Disconnected' if the socket has disconnected.
* 'ReceiveError' A receive error occurred.
* 'Success' All the bytes you wanted have been read * 'Success' All the bytes you wanted have been read
*/ */
Status receive_raw(void *data, size_t buffer_size, size_t &received) override; Status receive_raw(void *data, size_t buffer_size, size_t &received) override;
@ -65,8 +70,18 @@ public:
* *
* @note This must be set *WHILST* connected * @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise. * @param should_block True to block, false otherwise.
* @return State of the operation:
* 'Success' on success.
* 'Error' on failure.
*/ */
void set_blocking(bool should_block) override; Status set_blocking(bool should_block) override;
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
bool get_blocking() const override;
/*! /*!
* Sets the socket file descriptor. Internally used. * Sets the socket file descriptor. Internally used.
@ -104,6 +119,7 @@ protected:
void close_socket() override; void close_socket() override;
int32_t socket_descriptor; int32_t socket_descriptor;
bool is_blocking;
}; };
} }

View File

@ -18,6 +18,8 @@ namespace fr
class WebSocketBase class WebSocketBase
{ {
public: public:
virtual ~WebSocketBase()=default;
/*! /*!
* Checks if the socket is the client component or the server component * Checks if the socket is the client component or the server component
* *

View File

@ -996,7 +996,13 @@ namespace fr
Socket::Status Http::send(Socket *socket) const Socket::Status Http::send(Socket *socket) const
{ {
std::string data = construct(socket->get_remote_address()); std::string data = construct(socket->get_remote_address());
return socket->send_raw(&data[0], data.size()); size_t sent = 0;
fr::Socket::Status state;
do
{
state = socket->send_raw(&data[0], data.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
} }
Socket::Status Http::receive(Socket *socket) Socket::Status Http::receive(Socket *socket)

View File

@ -9,19 +9,47 @@
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
#include <frnetlib/SSLSocket.h> #include <frnetlib/SSLSocket.h>
mbedtls_net_context *net_create()
{
auto *ctx = new mbedtls_net_context;
mbedtls_net_init(ctx);
return ctx;
}
mbedtls_ssl_context *ssl_create()
{
auto *ctx = new mbedtls_ssl_context;
mbedtls_ssl_init(ctx);
return ctx;
}
void ssl_free(mbedtls_ssl_context *ctx)
{
mbedtls_ssl_free(ctx);
delete ctx;
}
void web_free(mbedtls_net_context *ctx)
{
mbedtls_net_free(ctx);
delete ctx;
}
namespace fr namespace fr
{ {
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: ssl_context(std::move(ssl_context_)), : ssl_context(std::move(ssl_context_)),
ssl_socket_descriptor(nullptr, web_free),
ssl(nullptr, ssl_free),
should_verify(true), should_verify(true),
receive_timeout(0) receive_timeout(0),
is_blocking(true)
{ {
//Initialise mbedtls structures //Initialise mbedtls structures
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
} }
SSLSocket::~SSLSocket() noexcept SSLSocket::~SSLSocket()
{ {
//Close connection if active //Close connection if active
close_socket(); close_socket();
@ -32,33 +60,34 @@ namespace fr
void SSLSocket::close_socket() void SSLSocket::close_socket()
{ {
if(ssl) ssl = nullptr;
{
mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get());
}
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
ssl_socket_descriptor = nullptr; ssl_socket_descriptor = nullptr;
} }
Socket::Status SSLSocket::send_raw(const char *data, size_t size) Socket::Status SSLSocket::send_raw(const char *data, size_t size, size_t &sent)
{ {
int response = 0; sent = 0;
size_t data_sent = 0; ssize_t status = 0;
while(data_sent < size)
while(sent < size)
{ {
response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent); status = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + sent, size - sent);
if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE) if(status < 0)
{ {
data_sent += response; if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
{
if(is_blocking)
{
continue;
} }
else if(response < 0) return Socket::Status::WouldBlock;
{ }
errno = response;
errno = status;
return Socket::Status::SSLError; return Socket::Status::SSLError;
} }
sent += status;
} }
return Socket::Status::Success; return Socket::Status::Success;
@ -68,6 +97,8 @@ namespace fr
{ {
ssize_t status = 0; ssize_t status = 0;
if(receive_timeout == 0) if(receive_timeout == 0)
{
do
{ {
status = mbedtls_ssl_read(ssl.get(), (unsigned char *) data, data_size); status = mbedtls_ssl_read(ssl.get(), (unsigned char *) data, data_size);
if(status == 0) if(status == 0)
@ -77,26 +108,38 @@ namespace fr
if(status < 0) if(status < 0)
{ {
if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
{
if(is_blocking)
{ {
return Socket::Status::WouldBlock; return Socket::Status::WouldBlock;
} }
continue;
}
errno = static_cast<int>(status); errno = static_cast<int>(status);
return Socket::Status::SSLError; return Socket::Status::SSLError;
} }
break;
} while(true);
} }
else else
{ {
do do
{ {
status = mbedtls_net_recv_timeout(ssl.get(), (unsigned char *)data, data_size, receive_timeout); status = mbedtls_net_recv_timeout(ssl.get(), (unsigned char *)data, data_size, receive_timeout);
if(status <= 0) if(status == 0)
{
return Socket::Status::Disconnected;
}
if(status < 0)
{ {
if(status == MBEDTLS_ERR_SSL_TIMEOUT) if(status == MBEDTLS_ERR_SSL_TIMEOUT)
{ {
return Socket::Status::WouldBlock; return Socket::Status::WouldBlock;
} }
else if(status == MBEDTLS_ERR_SSL_WANT_READ)
if(status == MBEDTLS_ERR_SSL_WANT_READ)
{ {
continue; //try again, interrupted before anything could be received continue; //try again, interrupted before anything could be received
} }
@ -117,10 +160,8 @@ namespace fr
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{ {
//Initialise mbedtls stuff //Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>(); ssl.reset(ssl_create());
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>(); ssl_socket_descriptor.reset(net_create());
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get());
//Due to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to //Due to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
//Open the descriptor, and then steal it. This is a hack. //Open the descriptor, and then steal it. This is a hack.
@ -171,10 +212,13 @@ namespace fr
} }
//Verify server certificate //Verify server certificate
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0)) if(should_verify)
{
if(mbedtls_ssl_get_verify_result(ssl.get()) != 0)
{ {
return Socket::Status::VerificationFailed; return Socket::Status::VerificationFailed;
} }
}
//Update state //Update state
reconfigure_socket(); reconfigure_socket();
@ -184,7 +228,7 @@ namespace fr
void SSLSocket::set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context) void SSLSocket::set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context)
{ {
ssl = std::move(context); ssl.reset(context.release());
} }
void SSLSocket::set_descriptor(void *descriptor) void SSLSocket::set_descriptor(void *descriptor)

View File

@ -16,8 +16,7 @@
namespace fr namespace fr
{ {
Socket::Socket() Socket::Socket()
: is_blocking(true), : ai_family(AF_UNSPEC),
ai_family(AF_UNSPEC),
max_receive_size(0), max_receive_size(0),
socket_read_timeout(0) socket_read_timeout(0)
{ {

View File

@ -77,14 +77,15 @@ namespace fr
break; break;
} }
//We're done with this now, cleanup
freeaddrinfo(info);
//Check that we've actually bound //Check that we've actually bound
if(c == nullptr) if(c == nullptr)
{ {
return Socket::Status::BindFailed; return Socket::Status::BindFailed;
} }
//We're done with this now, cleanup
freeaddrinfo(info);
//Listen to socket //Listen to socket
if(::listen(socket_descriptor, LISTEN_QUEUE_SIZE) == SOCKET_ERROR) if(::listen(socket_descriptor, LISTEN_QUEUE_SIZE) == SOCKET_ERROR)

View File

@ -13,7 +13,8 @@ namespace fr
{ {
TcpSocket::TcpSocket() noexcept TcpSocket::TcpSocket() noexcept
: socket_descriptor(-1) : socket_descriptor(-1),
is_blocking(true)
{ {
} }
@ -23,20 +24,27 @@ namespace fr
close_socket(); close_socket();
} }
Socket::Status TcpSocket::send_raw(const char *data, size_t size) Socket::Status TcpSocket::send_raw(const char *data, size_t size, size_t &sent)
{ {
size_t sent = 0;
while(sent < size) while(sent < size)
{ {
int64_t status = ::send(socket_descriptor, data + sent, size - sent, 0); int64_t status = ::send(socket_descriptor, data + sent, size - sent, 0);
if(status > 0) if(status > 0)
{ {
sent += status; sent += status;
continue;
} }
else if(errno != EWOULDBLOCK && errno != EAGAIN) //Don't exit if the socket just couldn't block
if(errno == EWOULDBLOCK)
{ {
return Socket::Status::SendError; return Socket::Status::WouldBlock;
} }
else if(errno == EINTR)
{
continue; //try again, interrupted before anything could be received
}
return Socket::Status::SendError;
} }
return Socket::Status::Success; return Socket::Status::Success;
} }
@ -64,7 +72,7 @@ namespace fr
if(status < 0) if(status < 0)
{ {
if(errno == EWOULDBLOCK || errno == EAGAIN) if(errno == EWOULDBLOCK)
{ {
return Socket::Status::WouldBlock; return Socket::Status::WouldBlock;
} }
@ -182,10 +190,12 @@ namespace fr
return Socket::Status::Success; return Socket::Status::Success;
} }
void TcpSocket::set_blocking(bool should_block) Socket::Status TcpSocket::set_blocking(bool should_block)
{ {
set_unix_socket_blocking(socket_descriptor, is_blocking, should_block); if(!set_unix_socket_blocking(socket_descriptor, is_blocking, should_block))
return Status::Error;
is_blocking = should_block; is_blocking = should_block;
return Socket::Success;
} }
int32_t TcpSocket::get_socket_descriptor() const noexcept int32_t TcpSocket::get_socket_descriptor() const noexcept
@ -226,5 +236,10 @@ namespace fr
return socket_descriptor > -1; return socket_descriptor > -1;
} }
bool TcpSocket::get_blocking() const
{
return is_blocking;
}
} }

View File

@ -77,7 +77,13 @@ namespace fr
} }
buffer.append(payload); buffer.append(payload);
return socket_->send_raw(buffer.c_str(), buffer.size()); size_t sent = 0;
fr::Socket::Status state;
do
{
state = socket_->send_raw(buffer.c_str(), buffer.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
} }
Socket::Status WebFrame::receive(Socket *socket) Socket::Status WebFrame::receive(Socket *socket)