diff --git a/include/frnetlib/Sendable.h b/include/frnetlib/Sendable.h index 61bb2f0..eff0bfb 100644 --- a/include/frnetlib/Sendable.h +++ b/include/frnetlib/Sendable.h @@ -19,7 +19,7 @@ namespace fr * sockets. * * @param socket The socket to send through - * @return Status indicating if the send succeeded or not. + * @return Status indicating if the send succeeded or not. This is dependent on the underlying type. */ virtual Socket::Status send(Socket *socket) const = 0; @@ -29,7 +29,7 @@ namespace fr * sockets. * * @param socket The socket to send through - * @return Status indicating if the send succeeded or not. + * @return Status indicating if the send succeeded or not. This is dependent on the underlying type. */ virtual Socket::Status receive(Socket *socket) = 0; }; diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index 6299d7f..1dc44bf 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -39,7 +39,8 @@ namespace fr ReceiveError = 17, AcceptError = 18, SSLError = 19, - NoRouteToHost = 20 + NoRouteToHost = 20, + Timeout = 21, //Remember to update status_to_string if more are added }; @@ -134,7 +135,7 @@ namespace fr * Send a Sendable object through the socket * * @param obj The object to send - * @return The status of the send + * @return The status of the send. This is dependant type being sent. */ virtual Status send(const Sendable &obj); @@ -229,11 +230,38 @@ namespace fr reconfigure_socket(); } + /*! + * Gets the socket receive timeout. + * + * @return Socket timeout in milliseconds. 0 if none. + */ inline uint32_t get_receive_timeout() const { return socket_read_timeout; } + /*! + * Sets a timeout which applies when sending data. + * + * @param timeout The maximum number of milliseconds to wait on a socket write before returning. Pass + * 0 (default) for no timeout. + */ + inline void set_send_timeout(uint32_t timeout) + { + socket_write_timeout = timeout; + reconfigure_socket(); + } + + /*! + * Gets the socket send timeout. + * + * @return Socket send timeout in milliseconds. 0 if none. + */ + inline uint32_t get_send_timeout() const + { + return socket_write_timeout; + } + /*! * Gets the max packet size. See set_max_packet_size * for more information. If this returns 0, then @@ -282,6 +310,7 @@ namespace fr int ai_family; uint32_t max_receive_size; uint32_t socket_read_timeout; + uint32_t socket_write_timeout; }; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 1d3e781..e94c2b0 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -78,7 +78,7 @@ namespace fr { if(is_blocking) { - continue; + return Socket::Status::Timeout; } return Socket::Status::WouldBlock; } @@ -111,7 +111,7 @@ namespace fr { if(is_blocking) { - return Socket::Status::WouldBlock; + return Socket::Status::Timeout; } continue; } diff --git a/src/Socket.cpp b/src/Socket.cpp index 25439fc..2f58503 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -17,8 +17,9 @@ namespace fr { Socket::Socket() : ai_family(AF_UNSPEC), - max_receive_size(0), - socket_read_timeout(0) + max_receive_size(0), + socket_read_timeout(0), + socket_write_timeout(0) { init_wsa(); } @@ -157,7 +158,8 @@ namespace fr } case NoRouteToHost: return "No Route To Host"; - break; + case Timeout: + return "Timeout"; default: return "Unknown"; } diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 2d3327b..e3b4743 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -29,7 +29,7 @@ namespace fr while(sent < size) { int64_t status = ::send(socket_descriptor, data + sent, size - sent, 0); - if(status > 0) + if(status >= 0) { sent += status; continue; @@ -37,6 +37,10 @@ namespace fr if(errno == EWOULDBLOCK) { + if(is_blocking) + { + return Socket::Status::Timeout; + } return Socket::Status::WouldBlock; } else if(errno == EINTR) @@ -74,6 +78,10 @@ namespace fr { if(errno == EWOULDBLOCK) { + if(is_blocking) + { + return Socket::Status::Timeout; + } return Socket::Status::WouldBlock; } else if(errno == EINTR) @@ -220,6 +228,11 @@ namespace fr tv.tv_sec = get_receive_timeout() / 1000; tv.tv_usec = (get_receive_timeout() % 1000) * 1000; setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); + + //Apply send timeout + tv.tv_sec = get_send_timeout() / 1000; + tv.tv_usec = (get_send_timeout() % 1000) * 1000; + setsockopt(socket_descriptor, SOL_SOCKET, SO_SNDTIMEO, (const char*)&tv, sizeof(tv)); #else //Disable Nagle's algorithm setsockopt(get_socket_descriptor(), IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)); @@ -228,6 +241,10 @@ namespace fr //Apply receive timeout DWORD timeout_dword = static_cast(get_receive_timeout()); setsockopt(socket_descriptor, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout_dword, sizeof timeout_dword); + + //Apply send timeout + timeout_dword = static_cast(get_send_timeout()); + setsockopt(socket_descriptor, SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout_dword, sizeof timeout_dword); #endif }