Bug fixes. fr::Socket::send_raw now behaves like receive_raw.

Added missing virtual destructors in abstract classes.
This commit is contained in:
Fred Nicolson 2019-03-21 16:24:22 +00:00
parent 7343529302
commit 6814219cbb
No known key found for this signature in database
GPG Key ID: 78C1DD87B47797D2
16 changed files with 308 additions and 169 deletions

View File

@ -7,6 +7,7 @@
#include <frnetlib/HttpRequest.h>
#include <frnetlib/URL.h>
#include <frnetlib/HttpResponse.h>
#include <frnetlib/TcpListener.h>
int main()
{
@ -23,20 +24,23 @@ int main()
return EXIT_FAILURE;
}
//Try to connect to the parsed address
fr::Socket::Status err;
fr::TcpSocket socket;
if(socket.connect(parsed_url.get_host(), parsed_url.get_port(), {}) != fr::Socket::Success)
fr::TcpListener listener;
//Try to connect to the parsed address
if((err = socket.connect(parsed_url.get_host(), parsed_url.get_port(), {})) != fr::Socket::Success)
{
std::cerr << "Failed to connect to the specified URL" << std::endl;
std::cerr << "Failed to connect to the specified URL: " << fr::Socket::status_to_string(err) << std::endl;
return EXIT_FAILURE;
}
//Construct a request, requesting the user provided URI
fr::HttpRequest request;
request.set_uri(parsed_url.get_uri());
if(socket.send(request) != fr::Socket::Success)
if((err = socket.send(request)) != fr::Socket::Success)
{
std::cerr << "Failed to send HTTP request" << std::endl;
std::cerr << "Failed to send HTTP request: " + fr::Socket::status_to_string(err) << std::endl;
return EXIT_FAILURE;
}

View File

@ -707,7 +707,13 @@ namespace fr
{
uint32_t length = htonl((uint32_t)buffer.size() - PACKET_HEADER_LENGTH);
memcpy(&buffer[0], &length, sizeof(uint32_t));
return socket->send_raw(buffer.c_str(), buffer.size());
fr::Socket::Status state;
size_t sent = 0;
do
{
state = socket->send_raw(&buffer[0], buffer.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
}
/*!

View File

@ -20,7 +20,7 @@ namespace fr
{
public:
explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
virtual ~SSLSocket() noexcept;
~SSLSocket() override;
SSLSocket(SSLSocket &&) = delete;
SSLSocket(const SSLSocket &) = delete;
void operator=(SSLSocket &&)=delete;
@ -30,11 +30,15 @@ namespace fr
* Effectively just fr::TcpSocket::send_raw() with encryption
* added in.
*
* @note If this returns WouldBlock, you must call this function again with the *same* arguments.
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out
* 'SSLError' An SSL error has occurred.
* 'Success' All the bytes you wanted have been read
*/
Socket::Status send_raw(const char *data, size_t size) override;
Socket::Status send_raw(const char *data, size_t size, size_t &sent) override;
/*!
@ -47,6 +51,7 @@ namespace fr
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out
* 'Disconnected' if the socket has disconnected.
* 'SSLError' An SSL error has occurred.
* 'Success' All the bytes you wanted have been read
*/
Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override;
@ -109,10 +114,31 @@ namespace fr
*
* @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise.
* @return A status code indicating success:
* 'SSLError' on failure.
* 'Success' on success.
*/
inline void set_blocking(bool should_block) override
inline fr::Socket::Status set_blocking(bool should_block) override
{
mbedtls_net_set_block(ssl_socket_descriptor.get());
int ret = mbedtls_net_set_block(ssl_socket_descriptor.get());
if(ret != 0)
{
errno = ret;
return fr::Socket::SSLError;
}
is_blocking = should_block;
return fr::Socket::Success;
}
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
inline bool get_blocking() const override
{
return is_blocking;
}
/*!
@ -134,13 +160,13 @@ namespace fr
void close_socket() override;
std::shared_ptr<SSLContext> ssl_context;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor;
std::unique_ptr<mbedtls_ssl_context> ssl;
std::unique_ptr<mbedtls_net_context, decltype(&mbedtls_net_free)> ssl_socket_descriptor;
std::unique_ptr<mbedtls_ssl_context, decltype(&mbedtls_ssl_free)> ssl;
mbedtls_ssl_config conf;
uint32_t flags;
bool should_verify;
uint32_t receive_timeout;
bool is_blocking;
};
}
#endif //FRNETLIB_SSLSOCKET_H
#endif //FRNETLIB_SSLSOCKET_H

View File

@ -11,6 +11,8 @@ namespace fr
class Sendable
{
public:
virtual ~Sendable()=default;
/*!
* Overridable send, to allow
* custom types to be directly sent through

View File

@ -84,8 +84,18 @@ namespace fr
*
* @note This must be set *WHILST* connected
* @param should_block True for blocking (default argument), false otherwise.
* @return Status of the operation:
* 'Success' on success.
* Something else (depending on underlying socket type) on failure.
*/
virtual void set_blocking(bool should_block) = 0;
virtual fr::Socket::Status set_blocking(bool should_block) = 0;
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
virtual bool get_blocking() const=0;
/*!
* Attempts to send raw data down the socket, without
@ -94,9 +104,10 @@ namespace fr
*
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
* @param sent The number of bytes that could be sent.
* @return The status of the operation. Dependent on the underlying socket type.
*/
virtual Status send_raw(const char *data, size_t size) = 0;
virtual Status send_raw(const char *data, size_t size, size_t &sent) = 0;
/*!
* Receives raw data from the socket, without any of
@ -268,7 +279,6 @@ namespace fr
virtual void reconfigure_socket()=0;
std::string remote_address;
bool is_blocking;
int ai_family;
uint32_t max_receive_size;
uint32_t socket_read_timeout;
@ -276,4 +286,4 @@ namespace fr
}
#endif //FRNETLIB_SOCKET_H
#endif //FRNETLIB_SOCKET_H

View File

@ -12,19 +12,21 @@ class SocketDescriptor
{
public:
virtual ~SocketDescriptor()=default;
/*!
* Checks to see if the socket is connected or not
*
* @return True if connected, false otherwise
*/
virtual bool connected() const noexcept = 0;
virtual bool connected() const = 0;
/*!
* Gets the underlying socket descriptor.
*
* @return The socket descriptor.
*/
virtual int32_t get_socket_descriptor() const noexcept = 0;
virtual int32_t get_socket_descriptor() const = 0;
};
}

View File

@ -27,7 +27,7 @@ public:
* @param port The port to bind to
* @return If the operation was successful
*/
virtual Socket::Status listen(const std::string &port) override;
Socket::Status listen(const std::string &port) override;
/*!
* Accepts a new connection.
@ -35,7 +35,7 @@ public:
* @param client Where to store the connection information
* @return True on success. False on failure.
*/
virtual Socket::Status accept(Socket &client) override;
Socket::Status accept(Socket &client) override;
/*!
* Calls the shutdown syscall on the socket.

View File

@ -11,102 +11,118 @@
namespace fr
{
class TcpSocket : public Socket
{
public:
TcpSocket() noexcept;
~TcpSocket() override;
TcpSocket(TcpSocket &&) = delete;
TcpSocket(const TcpSocket &) = delete;
void operator=(TcpSocket &&)=delete;
void operator=(const TcpSocket &)=delete;
class TcpSocket : public Socket
{
public:
TcpSocket() noexcept;
~TcpSocket() override;
TcpSocket(TcpSocket &&) = delete;
TcpSocket(const TcpSocket &) = delete;
void operator=(TcpSocket &&)=delete;
void operator=(const TcpSocket &)=delete;
/*!
* Connects the socket to an address.
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default.
* @return A Socket::Status indicating the status of the operation. (Success on success, an error type on failure).
*/
Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Connects the socket to an address.
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default.
* @return A Socket::Status indicating the status of the operation. (Success on success, an error type on failure).
*/
Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Attempts to send raw data down the socket, without
* any of frnetlib's framing. Useful for communicating through
* different protocols.
*
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
*/
Status send_raw(const char *data, size_t size) override;
/*!
* Attempts to send raw data down the socket, without
* any of frnetlib's framing. Useful for communicating through
* different protocols.
*
* @param data The data to send.
* @param size The number of bytes, from data to send. Be careful not to overflow. You must zero this, prior to calling send_raw the first time.
* @param sent Will be filled with the number of bytes actually sent, might be less than you requested if in non-blocking mode/error.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'SendError' if a send error has disconnected.
* 'Success' All the bytes you wanted have been read
*/
Status send_raw(const char *data, size_t size, size_t &sent) override;
/*!
* Receives raw data from the socket, without any of
* frnetlib's framing. Useful for communicating through
* different protocols. This will attempt to read 'data_size'
* bytes, but might not succeed. It'll return how many bytes were actually
* read in 'received'.
*
* @param data Where to store the received data.
* @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'Disconnected' if the socket has disconnected.
* 'Success' All the bytes you wanted have been read
*/
Status receive_raw(void *data, size_t buffer_size, size_t &received) override;
/*!
* Receives raw data from the socket, without any of
* frnetlib's framing. Useful for communicating through
* different protocols. This will attempt to read 'data_size'
* bytes, but might not succeed. It'll return how many bytes were actually
* read in 'received'.
*
* @param data Where to store the received data.
* @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested if in non-blocking mode/error.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'Disconnected' if the socket has disconnected.
* 'ReceiveError' A receive error occurred.
* 'Success' All the bytes you wanted have been read
*/
Status receive_raw(void *data, size_t buffer_size, size_t &received) override;
/*!
* Sets if the socket should be blocking or non-blocking.
*
* @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise.
*/
void set_blocking(bool should_block) override;
/*!
* Sets if the socket should be blocking or non-blocking.
*
* @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise.
* @return State of the operation:
* 'Success' on success.
* 'Error' on failure.
*/
Status set_blocking(bool should_block) override;
/*!
* Sets the socket file descriptor. Internally used.
*
* @note For TcpSocket, this should be a pointer to a int32_t. A copy is made.
* @param descriptor_data The socket descriptor data, set up by the Listener.
*/
void set_descriptor(void *descriptor_data) override;
/*!
* Checks if the socket is blocking
*
* @return True if it is, false otherwise
*/
bool get_blocking() const override;
/*!
* Applies requested socket options to the socket.
* Should be called when a new socket is created.
*/
void reconfigure_socket() override;
/*!
* Sets the socket file descriptor. Internally used.
*
* @note For TcpSocket, this should be a pointer to a int32_t. A copy is made.
* @param descriptor_data The socket descriptor data, set up by the Listener.
*/
void set_descriptor(void *descriptor_data) override;
/*!
* Checks to see if the socket is connected or not
*
* @return True if connected, false otherwise
*/
bool connected() const noexcept override;
/*!
* Applies requested socket options to the socket.
* Should be called when a new socket is created.
*/
void reconfigure_socket() override;
/*!
* Gets the underlying socket descriptor.
*
* @return The socket descriptor.
*/
int32_t get_socket_descriptor() const noexcept override;
/*!
* Checks to see if the socket is connected or not
*
* @return True if connected, false otherwise
*/
bool connected() const noexcept override;
protected:
/*!
* Gets the underlying socket descriptor.
*
* @return The socket descriptor.
*/
int32_t get_socket_descriptor() const noexcept override;
/*!
* Close the connection.
*/
void close_socket() override;
protected:
int32_t socket_descriptor;
};
/*!
* Close the connection.
*/
void close_socket() override;
int32_t socket_descriptor;
bool is_blocking;
};
}
#endif //FRNETLIB_TCPSOCKET_H
#endif //FRNETLIB_TCPSOCKET_H

View File

@ -18,6 +18,8 @@ namespace fr
class WebSocketBase
{
public:
virtual ~WebSocketBase()=default;
/*!
* Checks if the socket is the client component or the server component
*
@ -148,4 +150,4 @@ namespace fr
};
}
#endif //FRNETLIB_WEBSOCKET_H
#endif //FRNETLIB_WEBSOCKET_H

View File

@ -996,7 +996,13 @@ namespace fr
Socket::Status Http::send(Socket *socket) const
{
std::string data = construct(socket->get_remote_address());
return socket->send_raw(&data[0], data.size());
size_t sent = 0;
fr::Socket::Status state;
do
{
state = socket->send_raw(&data[0], data.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
}
Socket::Status Http::receive(Socket *socket)

View File

@ -16,7 +16,7 @@
namespace fr
{
SSLListener::SSLListener(std::shared_ptr<SSLContext> ssl_context_, const std::string &pem_path, const std::string &private_key_path)
: ssl_context(std::move(ssl_context_))
: ssl_context(std::move(ssl_context_))
{
//Initialise SSL objects required
listen_fd.fd = -1;

View File

@ -9,19 +9,47 @@
#include <mbedtls/net_sockets.h>
#include <frnetlib/SSLSocket.h>
mbedtls_net_context *net_create()
{
auto *ctx = new mbedtls_net_context;
mbedtls_net_init(ctx);
return ctx;
}
mbedtls_ssl_context *ssl_create()
{
auto *ctx = new mbedtls_ssl_context;
mbedtls_ssl_init(ctx);
return ctx;
}
void ssl_free(mbedtls_ssl_context *ctx)
{
mbedtls_ssl_free(ctx);
delete ctx;
}
void web_free(mbedtls_net_context *ctx)
{
mbedtls_net_free(ctx);
delete ctx;
}
namespace fr
{
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: ssl_context(std::move(ssl_context_)),
ssl_socket_descriptor(nullptr, web_free),
ssl(nullptr, ssl_free),
should_verify(true),
receive_timeout(0)
receive_timeout(0),
is_blocking(true)
{
//Initialise mbedtls structures
mbedtls_ssl_config_init(&conf);
}
SSLSocket::~SSLSocket() noexcept
SSLSocket::~SSLSocket()
{
//Close connection if active
close_socket();
@ -32,33 +60,34 @@ namespace fr
void SSLSocket::close_socket()
{
if(ssl)
{
mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get());
}
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
ssl = nullptr;
ssl_socket_descriptor = nullptr;
}
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
Socket::Status SSLSocket::send_raw(const char *data, size_t size, size_t &sent)
{
int response = 0;
size_t data_sent = 0;
while(data_sent < size)
sent = 0;
ssize_t status = 0;
while(sent < size)
{
response = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + data_sent, size - data_sent);
if(response != MBEDTLS_ERR_SSL_WANT_READ && response != MBEDTLS_ERR_SSL_WANT_WRITE)
status = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data + sent, size - sent);
if(status < 0)
{
data_sent += response;
}
else if(response < 0)
{
errno = response;
if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
{
if(is_blocking)
{
continue;
}
return Socket::Status::WouldBlock;
}
errno = status;
return Socket::Status::SSLError;
}
sent += status;
}
return Socket::Status::Success;
@ -69,34 +98,48 @@ namespace fr
ssize_t status = 0;
if(receive_timeout == 0)
{
status = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
if(status == 0)
do
{
return Socket::Status::Disconnected;
}
if(status < 0)
{
if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
status = mbedtls_ssl_read(ssl.get(), (unsigned char *) data, data_size);
if(status == 0)
{
return Socket::Status::WouldBlock;
return Socket::Status::Disconnected;
}
if(status < 0)
{
if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
{
if(is_blocking)
{
return Socket::Status::WouldBlock;
}
continue;
}
errno = static_cast<int>(status);
return Socket::Status::SSLError;
}
errno = static_cast<int>(status);
return Socket::Status::SSLError;
}
break;
} while(true);
}
else
{
do
{
status = mbedtls_net_recv_timeout(ssl.get(), (unsigned char *)data, data_size, receive_timeout);
if(status <= 0)
if(status == 0)
{
return Socket::Status::Disconnected;
}
if(status < 0)
{
if(status == MBEDTLS_ERR_SSL_TIMEOUT)
{
return Socket::Status::WouldBlock;
}
else if(status == MBEDTLS_ERR_SSL_WANT_READ)
if(status == MBEDTLS_ERR_SSL_WANT_READ)
{
continue; //try again, interrupted before anything could be received
}
@ -117,10 +160,8 @@ namespace fr
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{
//Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>();
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get());
ssl.reset(ssl_create());
ssl_socket_descriptor.reset(net_create());
//Due to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
//Open the descriptor, and then steal it. This is a hack.
@ -171,9 +212,12 @@ namespace fr
}
//Verify server certificate
if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0))
if(should_verify)
{
return Socket::Status::VerificationFailed;
if(mbedtls_ssl_get_verify_result(ssl.get()) != 0)
{
return Socket::Status::VerificationFailed;
}
}
//Update state
@ -184,7 +228,7 @@ namespace fr
void SSLSocket::set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context)
{
ssl = std::move(context);
ssl.reset(context.release());
}
void SSLSocket::set_descriptor(void *descriptor)
@ -220,4 +264,4 @@ namespace fr
setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword);
#endif
}
}
}

View File

@ -16,10 +16,9 @@
namespace fr
{
Socket::Socket()
: is_blocking(true),
ai_family(AF_UNSPEC),
max_receive_size(0),
socket_read_timeout(0)
: ai_family(AF_UNSPEC),
max_receive_size(0),
socket_read_timeout(0)
{
init_wsa();
}
@ -99,7 +98,7 @@ namespace fr
};
#define ERR_STR wsa_err_to_str(WSAGetLastError())
#else
#define ERR_STR strerror(errno)
#define ERR_STR strerror(errno)
#endif
switch(status)

View File

@ -13,7 +13,7 @@ namespace fr
const int no = 0;
TcpListener::TcpListener()
: socket_descriptor(-1)
: socket_descriptor(-1)
{
}
@ -77,14 +77,15 @@ namespace fr
break;
}
//We're done with this now, cleanup
freeaddrinfo(info);
//Check that we've actually bound
if(c == nullptr)
{
return Socket::Status::BindFailed;
}
//We're done with this now, cleanup
freeaddrinfo(info);
//Listen to socket
if(::listen(socket_descriptor, LISTEN_QUEUE_SIZE) == SOCKET_ERROR)

View File

@ -13,7 +13,8 @@ namespace fr
{
TcpSocket::TcpSocket() noexcept
: socket_descriptor(-1)
: socket_descriptor(-1),
is_blocking(true)
{
}
@ -23,20 +24,27 @@ namespace fr
close_socket();
}
Socket::Status TcpSocket::send_raw(const char *data, size_t size)
Socket::Status TcpSocket::send_raw(const char *data, size_t size, size_t &sent)
{
size_t sent = 0;
while(sent < size)
{
int64_t status = ::send(socket_descriptor, data + sent, size - sent, 0);
if(status > 0)
{
sent += status;
continue;
}
else if(errno != EWOULDBLOCK && errno != EAGAIN) //Don't exit if the socket just couldn't block
if(errno == EWOULDBLOCK)
{
return Socket::Status::SendError;
return Socket::Status::WouldBlock;
}
else if(errno == EINTR)
{
continue; //try again, interrupted before anything could be received
}
return Socket::Status::SendError;
}
return Socket::Status::Success;
}
@ -64,7 +72,7 @@ namespace fr
if(status < 0)
{
if(errno == EWOULDBLOCK || errno == EAGAIN)
if(errno == EWOULDBLOCK)
{
return Socket::Status::WouldBlock;
}
@ -182,10 +190,12 @@ namespace fr
return Socket::Status::Success;
}
void TcpSocket::set_blocking(bool should_block)
Socket::Status TcpSocket::set_blocking(bool should_block)
{
set_unix_socket_blocking(socket_descriptor, is_blocking, should_block);
if(!set_unix_socket_blocking(socket_descriptor, is_blocking, should_block))
return Status::Error;
is_blocking = should_block;
return Socket::Success;
}
int32_t TcpSocket::get_socket_descriptor() const noexcept
@ -226,5 +236,10 @@ namespace fr
return socket_descriptor > -1;
}
bool TcpSocket::get_blocking() const
{
return is_blocking;
}
}

View File

@ -77,7 +77,13 @@ namespace fr
}
buffer.append(payload);
return socket_->send_raw(buffer.c_str(), buffer.size());
size_t sent = 0;
fr::Socket::Status state;
do
{
state = socket_->send_raw(buffer.c_str(), buffer.size(), sent);
} while(state == fr::Socket::WouldBlock);
return state;
}
Socket::Status WebFrame::receive(Socket *socket)