Added ability to set/get recv timeouts. Fixed HTTP receive bug.

Recv timeouts can now be specified for sockets, which is the maximum amount of time to wait before returning during a receive. They will return WouldBlock if no data was received during the wait time.

Receiving a HTTP request in non-blocking mode will no longer fail.
This commit is contained in:
Fred Nicolson 2018-08-16 11:24:52 +01:00
parent decb0b10f9
commit 8a4ee937b1
9 changed files with 190 additions and 91 deletions

View File

@ -45,7 +45,7 @@ namespace fr
* @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation: * @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out
* 'Disconnected' if the socket has disconnected. * 'Disconnected' if the socket has disconnected.
* 'Success' All the bytes you wanted have been read * 'Success' All the bytes you wanted have been read
*/ */
@ -86,6 +86,12 @@ namespace fr
*/ */
void verify_certificates(bool should_verify); void verify_certificates(bool should_verify);
/*!
* Applies requested socket options to the socket.
* Should be called when a new socket is created.
*/
void reconfigure_socket() override;
/*! /*!
* Gets the underlying socket descriptor. * Gets the underlying socket descriptor.
* *
@ -133,6 +139,7 @@ namespace fr
mbedtls_ssl_config conf; mbedtls_ssl_config conf;
uint32_t flags; uint32_t flags;
bool should_verify; bool should_verify;
uint32_t receive_timeout;
}; };
} }

View File

@ -90,7 +90,7 @@ namespace fr
* @param data Where to store the received data. * @param data Where to store the received data.
* @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation, if the socket has disconnected etc. * @return The status of the operation, if the socket has disconnected etc. This is dependent on the underlying socket type.
*/ */
virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0; virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0;
@ -146,7 +146,7 @@ namespace fr
* @return Operation status: * @return Operation status:
* 'Disconnected' if the socket disconnected * 'Disconnected' if the socket disconnected
* 'Success' if buffer_size bytes could be read successfully * 'Success' if buffer_size bytes could be read successfully
* 'WouldBlock' if the socket is in blocking mode and no data could be read * 'WouldBlock' if the socket is in blocking mode and no data could be read, or if the read timed out before any data was received
*/ */
Status receive_all(void *dest, size_t buffer_size); Status receive_all(void *dest, size_t buffer_size);
@ -168,25 +168,6 @@ namespace fr
*/ */
void set_inet_version(IP version); void set_inet_version(IP version);
/*!
* Sets the maximum receivable size that may be received by the socket. This does
* not apply to receive_raw(), but only things like fr::Packet.
*
* If a client attempts to send a packet larger than sz bytes, then
* the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded
* will be returned. Pass '0' to indicate no limit.
*
* This should be used to prevent potential abuse, as a client could say that
* it's going to send a 200GiB packet, which would cause the Socket to try and
* allocate that much memory to accommodate the data, which is most likely not
* desirable.
*
* By default, there is no limit (0)
*
* @param sz The maximum number of bytes that may be received in an fr::Packet
*/
void set_max_receive_size(uint32_t sz);
/*! /*!
* Converts an fr::Socket::Status value to a printable string * Converts an fr::Socket::Status value to a printable string
* *
@ -206,6 +187,48 @@ namespace fr
*/ */
virtual void disconnect(); virtual void disconnect();
/*!
* Sets the maximum receivable size that may be received by the socket. This does
* not apply to receive_raw(), but only things like fr::Packet.
*
* If a client attempts to send a packet larger than sz bytes, then
* the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded
* will be returned. Pass '0' to indicate no limit.
*
* This should be used to prevent potential abuse, as a client could say that
* it's going to send a 200GiB packet, which would cause the Socket to try and
* allocate that much memory to accommodate the data, which is most likely not
* desirable.
*
* By default, there is no limit (0)
*
* @param sz The maximum number of bytes that may be received in an fr::Packet
*/
inline void set_max_receive_size(uint32_t sz)
{
max_receive_size = sz;
}
/*!
* Sets a timeout which applies when receiving data.
*
* @note When receiving framed data, such as with receive(), this timeout will apply to the underlying
* individual reads, but not for the message as a whole.
*
* @param timeout The maximum number of milliseconds to wait on a socket read before returning. Pass
* 0 (default) for no timeout.
*/
inline void set_receive_timeout(uint32_t timeout)
{
socket_read_timeout = timeout;
reconfigure_socket();
}
inline uint32_t get_receive_timeout() const
{
return socket_read_timeout;
}
/*! /*!
* Gets the max packet size. See set_max_packet_size * Gets the max packet size. See set_max_packet_size
* for more information. If this returns 0, then * for more information. If this returns 0, then
@ -248,12 +271,13 @@ namespace fr
* Applies requested socket options to the socket. * Applies requested socket options to the socket.
* Should be called when a new socket is created. * Should be called when a new socket is created.
*/ */
void reconfigure_socket(); virtual void reconfigure_socket()=0;
std::string remote_address; std::string remote_address;
bool is_blocking; bool is_blocking;
int ai_family; int ai_family;
uint32_t max_receive_size; uint32_t max_receive_size;
uint32_t socket_read_timeout;
}; };
} }

View File

@ -54,7 +54,7 @@ public:
* @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data.
* @param received Will be filled with the number of bytes actually received, might be less than you requested. * @param received Will be filled with the number of bytes actually received, might be less than you requested.
* @return The status of the operation: * @return The status of the operation:
* 'WouldBlock' if no data has been received, and the socket is in non-blocking mode * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out
* 'Disconnected' if the socket has disconnected. * 'Disconnected' if the socket has disconnected.
* 'Success' All the bytes you wanted have been read * 'Success' All the bytes you wanted have been read
*/ */
@ -83,6 +83,12 @@ public:
*/ */
int32_t get_socket_descriptor() const override; int32_t get_socket_descriptor() const override;
/*!
* Applies requested socket options to the socket.
* Should be called when a new socket is created.
*/
void reconfigure_socket() override;
/*! /*!
* Checks to see if we're connected to a socket or not * Checks to see if we're connected to a socket or not
* *

View File

@ -8,11 +8,11 @@
//Format: Major | Minor | Patch //Format: Major | Minor | Patch
#define FRNETLIB_VERSION_MAJOR 1 #define FRNETLIB_VERSION_MAJOR 1
#define FRNETLIB_VERSION_MINOR 0 #define FRNETLIB_VERSION_MINOR 1
#define FRNETLIB_VERSION_PATCH 2 #define FRNETLIB_VERSION_PATCH 0
#define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH) #define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH)
#define FRNETLIB_VERSION_STRING "1.0.2" #define FRNETLIB_VERSION_STRING "1.1.0"
#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2" #define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.0"
#endif //FRNETLIB_VERSION_H #endif //FRNETLIB_VERSION_H

View File

@ -999,13 +999,15 @@ namespace fr
Socket::Status Http::receive(Socket *socket) Socket::Status Http::receive(Socket *socket)
{ {
char recv_buffer[RECV_CHUNK_SIZE]; char recv_buffer[RECV_CHUNK_SIZE];
size_t received = 0;
fr::Socket::Status state; fr::Socket::Status state;
size_t total_received = 0;
size_t received = 0;
do do
{ {
//Receive the request //Receive the request
Socket::Status status = socket->receive_raw(recv_buffer, RECV_CHUNK_SIZE, received); Socket::Status status = socket->receive_raw(recv_buffer, RECV_CHUNK_SIZE, received);
if(status != Socket::Success) total_received += received;
if(status != Socket::Success && !(status == fr::Socket::WouldBlock && total_received != 0))
return status; return status;
//Parse it //Parse it

View File

@ -37,7 +37,7 @@ namespace fr
header_ended = header_end != std::string::npos; header_ended = header_end != std::string::npos;
//Ensure that the header doesn't exceed max length //Ensure that the header doesn't exceed max length
if(!header_ended && body.size() > MAX_HTTP_HEADER_SIZE || header_ended && header_end > MAX_HTTP_HEADER_SIZE) if((!header_ended && body.size() > MAX_HTTP_HEADER_SIZE) || (header_ended && header_end > MAX_HTTP_HEADER_SIZE))
{ {
return fr::Socket::HttpHeaderTooBig; return fr::Socket::HttpHeaderTooBig;
} }

View File

@ -7,12 +7,15 @@
#include <utility> #include <utility>
#include <mbedtls/net_sockets.h> #include <mbedtls/net_sockets.h>
#include <frnetlib/SSLSocket.h>
namespace fr namespace fr
{ {
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: ssl_context(std::move(ssl_context_)), : ssl_context(std::move(ssl_context_)),
should_verify(true) should_verify(true),
receive_timeout(0)
{ {
//Initialise mbedtls structures //Initialise mbedtls structures
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
@ -63,19 +66,46 @@ namespace fr
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
{ {
int read = 0; ssize_t status = 0;
received = 0; if(receive_timeout == 0)
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
if(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE)
return Socket::Status::WouldBlock;
if(read <= 0)
{ {
status = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
if(status <= 0)
{
if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
{
return Socket::Status::WouldBlock;
}
close_socket(); close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
} }
}
else
{
do
{
status = mbedtls_net_recv_timeout(ssl.get(), (unsigned char *)data, data_size, receive_timeout);
if(status <= 0)
{
if(status == MBEDTLS_ERR_SSL_TIMEOUT)
{
return Socket::Status::WouldBlock;
}
else if(status == MBEDTLS_ERR_SSL_WANT_READ)
{
continue; //try again, interrupted before anything could be received
}
received += read; close_socket();
return Socket::Status::Disconnected;
}
break;
} while(true);
}
received = static_cast<size_t>(status);
return Socket::Status::Success; return Socket::Status::Success;
} }
@ -163,4 +193,21 @@ namespace fr
{ {
should_verify = should_verify_; should_verify = should_verify_;
} }
void SSLSocket::reconfigure_socket()
{
int one = 1;
#ifndef _WIN32
//Disable Nagle's algorithm
setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one));
#else
//Disable Nagle's algorithm
setsockopt(get_socket_descriptor(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one));
setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one));
//Apply receive timeout
DWORD timeout_dword = static_cast<DWORD>(get_receive_timeout());
setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword);
#endif
}
} }

View File

@ -15,7 +15,8 @@ namespace fr
Socket::Socket() Socket::Socket()
: is_blocking(true), : is_blocking(true),
ai_family(AF_UNSPEC), ai_family(AF_UNSPEC),
max_receive_size(0) max_receive_size(0),
socket_read_timeout(0)
{ {
init_wsa(); init_wsa();
} }
@ -64,16 +65,6 @@ namespace fr
::shutdown(get_socket_descriptor(), SHUT_RDWR); ::shutdown(get_socket_descriptor(), SHUT_RDWR);
} }
void Socket::reconfigure_socket()
{
//todo: Perhaps allow for these settings to be modified
int one = 1;
setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one));
#ifdef _WIN32
setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one));
#endif
}
void Socket::set_inet_version(Socket::IP version) void Socket::set_inet_version(Socket::IP version)
{ {
switch(version) switch(version)
@ -92,11 +83,6 @@ namespace fr
} }
} }
void Socket::set_max_receive_size(uint32_t sz)
{
max_receive_size = sz;
}
const std::string &Socket::status_to_string(fr::Socket::Status status) const std::string &Socket::status_to_string(fr::Socket::Status status)
{ {
static std::vector<std::string> map = { static std::vector<std::string> map = {
@ -114,7 +100,8 @@ namespace fr
"Not enough data", "Not enough data",
"Parse error", "Parse error",
"HTTP header too big", "HTTP header too big",
"HTTP body too big"}; "HTTP body too big"
};
if(status < 0 || status > map.size()) if(status < 0 || status > map.size())
throw std::logic_error("Socket::status_to_string(): Invalid status value " + std::to_string(status)); throw std::logic_error("Socket::status_to_string(): Invalid status value " + std::to_string(status));

View File

@ -4,6 +4,8 @@
#include <iostream> #include <iostream>
#include <frnetlib/SocketSelector.h> #include <frnetlib/SocketSelector.h>
#include <frnetlib/TcpSocket.h>
#include "frnetlib/TcpSocket.h" #include "frnetlib/TcpSocket.h"
#define DEFAULT_SOCKET_TIMEOUT 20 #define DEFAULT_SOCKET_TIMEOUT 20
@ -51,29 +53,29 @@ namespace fr
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received) Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
{ {
received = 0; ssize_t status = 0;
do
//Read RECV_CHUNK_SIZE bytes into the recv buffer
int64_t status = ::recv(socket_descriptor, (char*)data, buffer_size, 0);
if(status > 0)
{ {
received += status; status = ::recv(socket_descriptor, (char*)data, buffer_size, 0);
} if(status <= 0)
else
{ {
if(errno == EWOULDBLOCK || errno == EAGAIN) if(errno == EWOULDBLOCK || errno == EAGAIN)
{ {
return Socket::Status::WouldBlock; return Socket::Status::WouldBlock;
} }
else if(errno == EINTR)
{
continue; //try again, interrupted before anything could be received
}
close_socket(); close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
} }
break;
} while(true);
if(received > buffer_size)
received = buffer_size;
received = static_cast<size_t>(status);
return Socket::Status::Success; return Socket::Status::Success;
} }
@ -185,4 +187,28 @@ namespace fr
{ {
return socket_descriptor; return socket_descriptor;
} }
void TcpSocket::reconfigure_socket()
{
int one = 1;
#ifndef _WIN32
//Disable Nagle's algorithm
setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one));
//Apply receive timeout
struct timeval tv = {};
tv.tv_sec = get_receive_timeout() / 1000;
tv.tv_usec = (get_receive_timeout() % 1000) * 1000;
setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
#else
//Disable Nagle's algorithm
setsockopt(get_socket_descriptor(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one));
setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one));
//Apply receive timeout
DWORD timeout_dword = static_cast<DWORD>(get_receive_timeout());
setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword);
#endif
}
} }