diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 1145250..336295c 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -45,7 +45,7 @@ namespace fr * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param received Will be filled with the number of bytes actually received, might be less than you requested. * @return The status of the operation: - * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode + * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or the operation has timed out * 'Disconnected' if the socket has disconnected. * 'Success' All the bytes you wanted have been read */ @@ -86,6 +86,12 @@ namespace fr */ void verify_certificates(bool should_verify); + /*! + * Applies requested socket options to the socket. + * Should be called when a new socket is created. + */ + void reconfigure_socket() override; + /*! * Gets the underlying socket descriptor. * @@ -133,6 +139,7 @@ namespace fr mbedtls_ssl_config conf; uint32_t flags; bool should_verify; + uint32_t receive_timeout; }; } diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index b23e6ae..bc9572d 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -90,7 +90,7 @@ namespace fr * @param data Where to store the received data. * @param data_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param received Will be filled with the number of bytes actually received, might be less than you requested. - * @return The status of the operation, if the socket has disconnected etc. + * @return The status of the operation, if the socket has disconnected etc. This is dependent on the underlying socket type. */ virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0; @@ -146,7 +146,7 @@ namespace fr * @return Operation status: * 'Disconnected' if the socket disconnected * 'Success' if buffer_size bytes could be read successfully - * 'WouldBlock' if the socket is in blocking mode and no data could be read + * 'WouldBlock' if the socket is in blocking mode and no data could be read, or if the read timed out before any data was received */ Status receive_all(void *dest, size_t buffer_size); @@ -168,25 +168,6 @@ namespace fr */ void set_inet_version(IP version); - /*! - * Sets the maximum receivable size that may be received by the socket. This does - * not apply to receive_raw(), but only things like fr::Packet. - * - * If a client attempts to send a packet larger than sz bytes, then - * the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded - * will be returned. Pass '0' to indicate no limit. - * - * This should be used to prevent potential abuse, as a client could say that - * it's going to send a 200GiB packet, which would cause the Socket to try and - * allocate that much memory to accommodate the data, which is most likely not - * desirable. - * - * By default, there is no limit (0) - * - * @param sz The maximum number of bytes that may be received in an fr::Packet - */ - void set_max_receive_size(uint32_t sz); - /*! * Converts an fr::Socket::Status value to a printable string * @@ -206,6 +187,48 @@ namespace fr */ virtual void disconnect(); + /*! + * Sets the maximum receivable size that may be received by the socket. This does + * not apply to receive_raw(), but only things like fr::Packet. + * + * If a client attempts to send a packet larger than sz bytes, then + * the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded + * will be returned. Pass '0' to indicate no limit. + * + * This should be used to prevent potential abuse, as a client could say that + * it's going to send a 200GiB packet, which would cause the Socket to try and + * allocate that much memory to accommodate the data, which is most likely not + * desirable. + * + * By default, there is no limit (0) + * + * @param sz The maximum number of bytes that may be received in an fr::Packet + */ + inline void set_max_receive_size(uint32_t sz) + { + max_receive_size = sz; + } + + /*! + * Sets a timeout which applies when receiving data. + * + * @note When receiving framed data, such as with receive(), this timeout will apply to the underlying + * individual reads, but not for the message as a whole. + * + * @param timeout The maximum number of milliseconds to wait on a socket read before returning. Pass + * 0 (default) for no timeout. + */ + inline void set_receive_timeout(uint32_t timeout) + { + socket_read_timeout = timeout; + reconfigure_socket(); + } + + inline uint32_t get_receive_timeout() const + { + return socket_read_timeout; + } + /*! * Gets the max packet size. See set_max_packet_size * for more information. If this returns 0, then @@ -248,12 +271,13 @@ namespace fr * Applies requested socket options to the socket. * Should be called when a new socket is created. */ - void reconfigure_socket(); + virtual void reconfigure_socket()=0; std::string remote_address; bool is_blocking; int ai_family; uint32_t max_receive_size; + uint32_t socket_read_timeout; }; } diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index b75d920..0464ecb 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -54,7 +54,7 @@ public: * @param buffer_size The number of bytes to try and receive. Be sure that it's not larger than data. * @param received Will be filled with the number of bytes actually received, might be less than you requested. * @return The status of the operation: - * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode + * 'WouldBlock' if no data has been received, and the socket is in non-blocking mode or operation has timed out * 'Disconnected' if the socket has disconnected. * 'Success' All the bytes you wanted have been read */ @@ -83,6 +83,12 @@ public: */ int32_t get_socket_descriptor() const override; + /*! + * Applies requested socket options to the socket. + * Should be called when a new socket is created. + */ + void reconfigure_socket() override; + /*! * Checks to see if we're connected to a socket or not * diff --git a/include/frnetlib/version.h b/include/frnetlib/version.h index 27f89f2..16de973 100644 --- a/include/frnetlib/version.h +++ b/include/frnetlib/version.h @@ -8,11 +8,11 @@ //Format: Major | Minor | Patch #define FRNETLIB_VERSION_MAJOR 1 -#define FRNETLIB_VERSION_MINOR 0 -#define FRNETLIB_VERSION_PATCH 2 +#define FRNETLIB_VERSION_MINOR 1 +#define FRNETLIB_VERSION_PATCH 0 #define FRNETLIB_VERSION_NUMBER (FRNETLIB_VERSION_MAJOR * 100*100 + FRNETLIB_VERSION_MINOR * 100 + FRNETLIB_VERSION_PATCH) -#define FRNETLIB_VERSION_STRING "1.0.2" -#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.2" +#define FRNETLIB_VERSION_STRING "1.1.0" +#define FRNETLIB_VERSION_STRING_FULL "frnetlib 1.0.0" #endif //FRNETLIB_VERSION_H diff --git a/src/Http.cpp b/src/Http.cpp index 756e2ba..6790ccd 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -999,13 +999,15 @@ namespace fr Socket::Status Http::receive(Socket *socket) { char recv_buffer[RECV_CHUNK_SIZE]; - size_t received = 0; fr::Socket::Status state; + size_t total_received = 0; + size_t received = 0; do { //Receive the request Socket::Status status = socket->receive_raw(recv_buffer, RECV_CHUNK_SIZE, received); - if(status != Socket::Success) + total_received += received; + if(status != Socket::Success && !(status == fr::Socket::WouldBlock && total_received != 0)) return status; //Parse it diff --git a/src/HttpRequest.cpp b/src/HttpRequest.cpp index 35a6e4e..0e0a56a 100644 --- a/src/HttpRequest.cpp +++ b/src/HttpRequest.cpp @@ -37,7 +37,7 @@ namespace fr header_ended = header_end != std::string::npos; //Ensure that the header doesn't exceed max length - if(!header_ended && body.size() > MAX_HTTP_HEADER_SIZE || header_ended && header_end > MAX_HTTP_HEADER_SIZE) + if((!header_ended && body.size() > MAX_HTTP_HEADER_SIZE) || (header_ended && header_end > MAX_HTTP_HEADER_SIZE)) { return fr::Socket::HttpHeaderTooBig; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 259347c..4b97f47 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -7,12 +7,15 @@ #include #include +#include + namespace fr { SSLSocket::SSLSocket(std::shared_ptr ssl_context_) noexcept : ssl_context(std::move(ssl_context_)), - should_verify(true) + should_verify(true), + receive_timeout(0) { //Initialise mbedtls structures mbedtls_ssl_config_init(&conf); @@ -63,19 +66,46 @@ namespace fr Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) { - int read = 0; - received = 0; - read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); - if(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) - return Socket::Status::WouldBlock; - - if(read <= 0) + ssize_t status = 0; + if(receive_timeout == 0) { - close_socket(); - return Socket::Status::Disconnected; + status = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); + if(status <= 0) + { + if(status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) + { + return Socket::Status::WouldBlock; + } + + close_socket(); + return Socket::Status::Disconnected; + } + } + else + { + do + { + status = mbedtls_net_recv_timeout(ssl.get(), (unsigned char *)data, data_size, receive_timeout); + if(status <= 0) + { + if(status == MBEDTLS_ERR_SSL_TIMEOUT) + { + return Socket::Status::WouldBlock; + } + else if(status == MBEDTLS_ERR_SSL_WANT_READ) + { + continue; //try again, interrupted before anything could be received + } + + close_socket(); + return Socket::Status::Disconnected; + } + break; + } while(true); } - received += read; + + received = static_cast(status); return Socket::Status::Success; } @@ -163,4 +193,21 @@ namespace fr { should_verify = should_verify_; } + + void SSLSocket::reconfigure_socket() + { + int one = 1; +#ifndef _WIN32 + //Disable Nagle's algorithm + setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); +#else + //Disable Nagle's algorithm + setsockopt(get_socket_descriptor(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); + setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one)); + + //Apply receive timeout + DWORD timeout_dword = static_cast(get_receive_timeout()); + setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword); +#endif + } } diff --git a/src/Socket.cpp b/src/Socket.cpp index 328f071..79c01c3 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -15,7 +15,8 @@ namespace fr Socket::Socket() : is_blocking(true), ai_family(AF_UNSPEC), - max_receive_size(0) + max_receive_size(0), + socket_read_timeout(0) { init_wsa(); } @@ -64,16 +65,6 @@ namespace fr ::shutdown(get_socket_descriptor(), SHUT_RDWR); } - void Socket::reconfigure_socket() - { - //todo: Perhaps allow for these settings to be modified - int one = 1; - setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); -#ifdef _WIN32 - setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one)); -#endif - } - void Socket::set_inet_version(Socket::IP version) { switch(version) @@ -92,29 +83,25 @@ namespace fr } } - void Socket::set_max_receive_size(uint32_t sz) - { - max_receive_size = sz; - } - const std::string &Socket::status_to_string(fr::Socket::Status status) { static std::vector map = { - "Unknown", - "Success", - "Listen Failed", - "Bind Failed", - "Disconnected", - "Error", - "Would Block", - "Connection Failed", - "Handshake Failed", - "Verification Failed", - "Max packet size exceeded", - "Not enough data", - "Parse error", - "HTTP header too big", - "HTTP body too big"}; + "Unknown", + "Success", + "Listen Failed", + "Bind Failed", + "Disconnected", + "Error", + "Would Block", + "Connection Failed", + "Handshake Failed", + "Verification Failed", + "Max packet size exceeded", + "Not enough data", + "Parse error", + "HTTP header too big", + "HTTP body too big" + }; if(status < 0 || status > map.size()) throw std::logic_error("Socket::status_to_string(): Invalid status value " + std::to_string(status)); diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index c5d5cd5..1ab8a21 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -4,6 +4,8 @@ #include #include +#include + #include "frnetlib/TcpSocket.h" #define DEFAULT_SOCKET_TIMEOUT 20 @@ -51,29 +53,29 @@ namespace fr Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received) { - received = 0; - - //Read RECV_CHUNK_SIZE bytes into the recv buffer - int64_t status = ::recv(socket_descriptor, (char*)data, buffer_size, 0); - - if(status > 0) + ssize_t status = 0; + do { - received += status; - } - else - { - if(errno == EWOULDBLOCK || errno == EAGAIN) + status = ::recv(socket_descriptor, (char*)data, buffer_size, 0); + if(status <= 0) { - return Socket::Status::WouldBlock; + if(errno == EWOULDBLOCK || errno == EAGAIN) + { + return Socket::Status::WouldBlock; + } + else if(errno == EINTR) + { + continue; //try again, interrupted before anything could be received + } + + close_socket(); + return Socket::Status::Disconnected; } + break; + } while(true); - close_socket(); - return Socket::Status::Disconnected; - } - - if(received > buffer_size) - received = buffer_size; + received = static_cast(status); return Socket::Status::Success; } @@ -185,4 +187,28 @@ namespace fr { return socket_descriptor; } + + void TcpSocket::reconfigure_socket() + { + int one = 1; +#ifndef _WIN32 + //Disable Nagle's algorithm + setsockopt(get_socket_descriptor(), SOL_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); + + //Apply receive timeout + struct timeval tv = {}; + tv.tv_sec = get_receive_timeout() / 1000; + tv.tv_usec = (get_receive_timeout() % 1000) * 1000; + setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); +#else + //Disable Nagle's algorithm + setsockopt(get_socket_descriptor(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); + setsockopt(get_socket_descriptor(), SOL_SOCKET, SO_EXCLUSIVEADDRUSE, (char*)&one, sizeof(one)); + + //Apply receive timeout + DWORD timeout_dword = static_cast(get_receive_timeout()); + setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword); +#endif + } + } \ No newline at end of file