Bug fixes. Documentation improvements.

Fixed SSLListener failing to accept SSLSockets properly (not setting the descriptor properly).

TcpSocket::receive_raw and SSLSocket::receive_raw now behave the same, rather than SSLSocket acting more like Socket::receive_all.

Documented specific return values from Socket::receive_all().

Socket::receive_all now returns WouldBlock if the socket is in blocking mode and the first read returns no data, so it doesn't behave like a blocking socket.

Disabled copying/moving of sockets. Copying shouldn't have been enabled, but might add move constructors in the future.

Added Socket::disconnect, which internally just calls close_socket, to allow for protocol-specific disconnect sequences in the future (WebSockets).
This commit is contained in:
Unknown 2018-02-28 23:44:31 +00:00
parent 0840c07e24
commit 62d8b7ba63
9 changed files with 164 additions and 131 deletions

View File

@ -20,10 +20,11 @@ namespace fr
{
public:
explicit SSLSocket(std::shared_ptr<SSLContext> ssl_context) noexcept;
virtual ~SSLSocket() noexcept;
SSLSocket(SSLSocket &&) noexcept = default;
SSLSocket(SSLSocket &&) = delete;
SSLSocket(const SSLSocket &) = delete;
void operator=(SSLSocket &&)=delete;
void operator=(const SSLSocket &)=delete;
/*!
* Effectively just fr::TcpSocket::send_raw() with encryption
@ -43,7 +44,10 @@ namespace fr
* @param data Where to store the received data.
* @param data_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode
* 'Disconnected' if the socket has disconnected.
* 'Success' All the bytes you wanted have been read
*/
Socket::Status receive_raw(void *data, size_t data_size, size_t &received) override;
@ -60,14 +64,15 @@ namespace fr
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
*/
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Sets the socket file descriptor.
* Sets the socket file descriptor. Internally used.
*
* @param descriptor The socket descriptor.
* @note For SSLSocket, this should be a pointer to a heap allocated mbedtls_net_context. Pointer ownership will be taken over by the SSLSocket.
* @param descriptor_data The socket descriptor data, set up by the Listener.
*/
virtual void set_descriptor(int descriptor) override;
void set_descriptor(void *descriptor_data) override;
/*!
* Set the SSL context
@ -76,36 +81,6 @@ namespace fr
*/
void set_ssl_context(std::unique_ptr<mbedtls_ssl_context> context);
/*!
* Gets the underlying socket descriptor.
*
* @return The socket's descriptor.
*/
int32_t get_socket_descriptor() const override
{
return ssl_socket_descriptor.fd;
}
/*!
* Sets if the socket should block or not.
*
* @param should_block True to block, false otherwise.
*/
void set_blocking(bool should_block) override
{
abort();
}
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline bool connected() const final
{
return ssl_socket_descriptor.fd > -1;
}
/*!
* Sets if the socket should verify the endpoints
* certificates or not. Verification is enforced
@ -116,10 +91,43 @@ namespace fr
*/
void verify_certificates(bool should_verify);
/*!
* Gets the underlying socket descriptor.
*
* @return The socket's descriptor.
*/
inline int32_t get_socket_descriptor() const override
{
if(!ssl_socket_descriptor)
return -1;
return ssl_socket_descriptor->fd;
}
/*!
* Sets if the socket should block or not.
*
* @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise.
*/
inline void set_blocking(bool should_block) override
{
mbedtls_net_set_block(ssl_socket_descriptor.get());
}
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline bool connected() const final
{
return ssl_socket_descriptor && ssl_socket_descriptor->fd > -1;
}
private:
std::shared_ptr<SSLContext> ssl_context;
mbedtls_net_context ssl_socket_descriptor;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor;
std::unique_ptr<mbedtls_ssl_context> ssl;
mbedtls_ssl_config conf;
uint32_t flags;

View File

@ -19,7 +19,7 @@ namespace fr
* @param socket The socket to send through
* @return Status indicating if the send succeeded or not.
*/
virtual Socket::Status send(Socket *socket) = 0; //TODO: RETURN PROPER VALUE FROM HTTP PARSE
virtual Socket::Status send(Socket *socket) = 0;
/*!
* Overrideable receive, to allow

View File

@ -45,26 +45,18 @@ namespace fr
Socket() noexcept;
virtual ~Socket() noexcept = default;
Socket(Socket &&o) noexcept
{
remote_address = std::move(o.remote_address);
is_blocking = o.is_blocking;
ai_family = o.ai_family;
max_receive_size = o.max_receive_size;
}
/*!
* Close the connection.
*/
virtual void close_socket()=0;
Socket(Socket &&) =delete;
Socket(const Socket &) =delete;
void operator=(const Socket &) =delete;
void operator=(Socket &&) =delete;
/*!
* Connects the socket to an address.
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default.
* @return A Socket::Status indicating the status of the operation. (Success on success, an error type on failure).
*/
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0;
@ -72,6 +64,7 @@ namespace fr
/*!
* Sets the socket to blocking or non-blocking.
*
* @note This must be set *WHILST* connected
* @param should_block True for blocking (default argument), false otherwise.
*/
virtual void set_blocking(bool should_block) = 0;
@ -116,31 +109,11 @@ namespace fr
virtual bool connected() const =0;
/*!
* Sets the socket file descriptor.
* Sets the socket file descriptor. Internally used.
*
* @param descriptor The socket descriptor.
* @param descriptor_data The socket descriptor data, set up by the Listener.
*/
virtual void set_descriptor(int descriptor)=0;
/*!
* Gets the socket's printable remote address
*
* @return The string address
*/
inline const std::string &get_remote_address()
{
return remote_address;
}
/*!
* Sets the connections remote address.
*
* @param addr The remote address to use
*/
void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
virtual void set_descriptor(void *descriptor_data)=0;
/*!
* Send a Sendable object through the socket
@ -148,16 +121,19 @@ namespace fr
* @param obj The object to send
* @return The status of the send
*/
Status send(Sendable &obj);
Status send(Sendable &&obj);
virtual Status send(Sendable &obj);
virtual Status send(Sendable &&obj);
/*!
* Receive a Sendable object through the socket
*
* @param obj The object to receive
* @return The status of the receive
* 'Disconnected' if the socket disconnected
* 'Success' if the object could be read successfully
* 'WouldBlock' if the socket is in blocking mode and no data could be read
*/
Status receive(Sendable &obj);
virtual Status receive(Sendable &obj);
/*!
* Reads size bytes into dest from the socket.
@ -168,7 +144,10 @@ namespace fr
*
* @param dest Where to read the data into
* @param buffer_size The number of bytes to read
* @return Operation status.
* @return Operation status:
* 'Disconnected' if the socket disconnected
* 'Success' if buffer_size bytes could be read successfully
* 'WouldBlock' if the socket is in blocking mode and no data could be read
*/
Status receive_all(void *dest, size_t buffer_size);
@ -207,6 +186,25 @@ namespace fr
*/
void set_max_receive_size(uint32_t sz);
/*!
* Converts an fr::Socket::Status value to a printable string
*
* Throws an std::logic_error if status is out of range.
*
* @param status Status value to convert
* @return A string form version
*/
static const std::string &status_to_string(fr::Socket::Status status);
/*!
* Ends, and closes the connection.
* There is a distinction between 'disconnect' and 'close_socket',
* in that 'disconnect' should end the connection properly (such as sending
* disconnect packets depending on the protocol), before calling 'close_socket' itself.
* 'close_socket' should just close the client connection and be done with it.
*/
virtual void disconnect();
/*!
* Gets the max packet size. See set_max_packet_size
* for more information.
@ -220,16 +218,31 @@ namespace fr
}
/*!
* Converts an fr::Socket::Status value to a printable string
* Gets the socket's printable remote address
*
* Throws an std::logic_error if status is out of range.
*
* @param status Status value to convert
* @return A string form version
* @return The string address
*/
static const std::string &status_to_string(fr::Socket::Status status);
inline const std::string &get_remote_address()
{
return remote_address;
}
/*!
* Sets the connections remote address.
*
* @param addr The remote address to use
*/
inline void set_remote_address(const std::string &addr)
{
remote_address = addr;
}
protected:
/*!
* Close the connection.
*/
virtual void close_socket()=0;
/*!
* Applies requested socket options to the socket.
* Should be called when a new socket is created.

View File

@ -16,9 +16,10 @@ class TcpSocket : public Socket
public:
TcpSocket() noexcept;
virtual ~TcpSocket() noexcept;
TcpSocket(TcpSocket &&other) noexcept
: socket_descriptor(other.socket_descriptor){}
void operator=(const TcpSocket &other)=delete;
TcpSocket(TcpSocket &&) = delete;
TcpSocket(const TcpSocket &) = delete;
void operator=(TcpSocket &&)=delete;
void operator=(const TcpSocket &)=delete;
/*!
* Close the connection.
@ -30,17 +31,10 @@ public:
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default.
* @return A Socket::Status indicating the status of the operation. (Success on success, an error type on failure).
*/
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Sets the socket file descriptor.
*
* @param descriptor The socket descriptor.
*/
virtual void set_descriptor(int descriptor) override;
Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Attempts to send raw data down the socket, without
@ -51,7 +45,7 @@ public:
* @param size The number of bytes, from data to send. Be careful not to overflow.
* @return The status of the operation.
*/
virtual Status send_raw(const char *data, size_t size) override;
Status send_raw(const char *data, size_t size) override;
/*!
@ -64,17 +58,29 @@ public:
* @param data Where to store the received data.
* @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc.
* @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode
* 'Disconnected' if the socket has disconnected.
* 'Success' All the bytes you wanted have been read
*/
virtual Status receive_raw(void *data, size_t buffer_size, size_t &received) override;
Status receive_raw(void *data, size_t buffer_size, size_t &received) override;
/*!
* Sets if the socket should be blocking or non-blocking.
*
* @note This must be set *WHILST* connected
* @param should_block True to block, false otherwise.
*/
void set_blocking(bool should_block) override;
/*!
* Sets the socket file descriptor. Internally used.
*
* @note For TcpSocket, this should be a pointer to a int32_t. A copy is made.
* @param descriptor_data The socket descriptor data, set up by the Listener.
*/
void set_descriptor(void *descriptor_data) override;
/*!
* Gets the unerlying socket descriptor
*

View File

@ -91,11 +91,11 @@ namespace fr
//Initialise mbedtls
int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
mbedtls_net_context client_fd;
auto client_fd = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(&client_fd);
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(&client_fd);};
mbedtls_net_init(client_fd.get());
auto free_contexts = [&](){mbedtls_ssl_free(ssl.get()); mbedtls_net_free(client_fd.get());};
if((error = mbedtls_ssl_setup(ssl.get(), &conf ) ) != 0)
{
std::cout << "Failed to apply SSL setings: " << error << std::endl;
@ -106,14 +106,14 @@ namespace fr
//Accept a connection
char client_ip[INET6_ADDRSTRLEN] = {0};
size_t ip_len = 0;
if((error = mbedtls_net_accept(&listen_fd, &client_fd, client_ip, sizeof(client_ip), &ip_len)) != 0)
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), client_ip, sizeof(client_ip), &ip_len)) != 0)
{
std::cout << "Accept error: " << error << std::endl;
free_contexts();
return Socket::Error;
}
mbedtls_ssl_set_bio(ssl.get(), &client_fd, mbedtls_net_send, mbedtls_net_recv, nullptr);
mbedtls_ssl_set_bio(ssl.get(), client_fd.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//SSL Handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
@ -132,14 +132,14 @@ namespace fr
char client_printable_addr[INET6_ADDRSTRLEN];
struct sockaddr_storage socket_address{};
socklen_t socket_length;
error = getpeername(client_fd.fd, (struct sockaddr*)&socket_address, &socket_length);
error = getpeername(client_fd->fd, (struct sockaddr*)&socket_address, &socket_length);
if(error == 0)
error = getnameinfo((sockaddr*)&socket_address, socket_length, client_printable_addr, sizeof(client_printable_addr), nullptr,0,NI_NUMERICHOST);
if(error != 0)
strcpy(client_printable_addr, "unknown");
client.set_ssl_context(std::move(ssl));
client.set_descriptor(client_fd.fd);
client.set_descriptor(client_fd.release());
client.set_remote_address(client_printable_addr);
return Socket::Success;
}

View File

@ -16,7 +16,6 @@ namespace fr
{
//Initialise mbedtls structures
mbedtls_ssl_config_init(&conf);
ssl_socket_descriptor.fd = -1;
}
SSLSocket::~SSLSocket() noexcept
@ -35,9 +34,10 @@ namespace fr
mbedtls_ssl_close_notify(ssl.get());
mbedtls_ssl_free(ssl.get());
}
if(ssl_socket_descriptor.fd > -1)
mbedtls_net_free(&ssl_socket_descriptor);
ssl_socket_descriptor.fd = -1;
if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get());
ssl_socket_descriptor = nullptr;
}
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
@ -63,13 +63,11 @@ namespace fr
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
{
int read = MBEDTLS_ERR_SSL_WANT_READ;
int read = 0;
received = 0;
while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE)
{
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
}
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
if(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE)
return Socket::Status::WouldBlock;
if(read <= 0)
{
@ -87,7 +85,7 @@ namespace fr
//Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(&ssl_socket_descriptor);
mbedtls_net_init(ssl_socket_descriptor.get());
//Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
//Open the descriptor, and then steal it. This is a hack.
@ -96,9 +94,10 @@ namespace fr
auto ret = socket.connect(address, port, timeout);
if(ret != fr::Socket::Success)
return ret;
ssl_socket_descriptor.fd = socket.get_socket_descriptor();
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
ssl_socket_descriptor->fd = socket.get_socket_descriptor();
remote_address = socket.get_remote_address();
socket.set_descriptor(-1);
socket.set_descriptor(nullptr);
}
//Initialise SSL data structures
@ -155,9 +154,9 @@ namespace fr
ssl = std::move(context);
}
void SSLSocket::set_descriptor(int descriptor)
void SSLSocket::set_descriptor(void *descriptor)
{
ssl_socket_descriptor.fd = descriptor;
ssl_socket_descriptor.reset(static_cast<mbedtls_net_context*>(descriptor));
reconfigure_socket();
}

View File

@ -56,10 +56,12 @@ namespace fr
size_t received = 0;
auto *arr = (char*)dest;
Status status = receive_raw(&arr[bytes_read], (size_t)bytes_remaining, received);
if(status != fr::Socket::Success)
if(status == fr::Socket::Disconnected)
return status;
bytes_remaining -= received;
bytes_read += received;
if(status == fr::Socket::WouldBlock && bytes_read == 0)
return status;
}
return Socket::Status::Success;
@ -126,4 +128,9 @@ namespace fr
throw std::logic_error("Socket::status_to_string(): Invalid status value " + std::to_string(status));
return map[status];
}
void Socket::disconnect()
{
close_socket();
}
}

View File

@ -86,7 +86,7 @@ namespace fr
Socket::Status TcpListener::accept(Socket &client_)
{
//Cast to TcpSocket. Will throw bad cast on failure.
auto &client = dynamic_cast<TcpSocket&>(client_);
TcpSocket &client = dynamic_cast<TcpSocket&>(client_);
//Prepare to wait for the client
sockaddr_storage client_addr{};
@ -105,7 +105,7 @@ namespace fr
strcpy(client_printable_addr, "unknown");
//Set client data
client.set_descriptor(client_descriptor);
client.set_descriptor(&client_descriptor);
client.set_remote_address(client_printable_addr);
return Socket::Success;

View File

@ -78,10 +78,10 @@ namespace fr
}
void TcpSocket::set_descriptor(int descriptor)
void TcpSocket::set_descriptor(void *descriptor)
{
socket_descriptor = *static_cast<int32_t*>(descriptor);
reconfigure_socket();
socket_descriptor = descriptor;
}
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)