Added support for connect timeouts.

Both fr::TcpSocket and fr::SSLSocket can have timeouts specified when connecting. This works by putting the socket into non-blocking mode, making a connect, and then selecting on the socket for the requested timeout. If the select times out then we've failed to connect, if it didn't time out then we've connected.
This commit is contained in:
Fred Nicolson 2018-01-10 17:08:16 +00:00
parent 2215d068af
commit 30354f15bc
10 changed files with 97 additions and 29 deletions

View File

@ -6,7 +6,7 @@ set(FRNETLIB_LINK_LIBRARIES "")
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules) set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules)
#User options #User options
option(USE_SSL "Use SSL" OFF) option(USE_SSL "Use SSL" ON)
set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.") set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.")
set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes") set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes")
set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes") set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes")

View File

@ -82,20 +82,27 @@ inline double ntohd(double val)
return val; return val;
} }
inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block) inline bool set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
{ {
//Don't update it if we're already in that mode //Don't update it if we're already in that mode
if(should_block == is_blocking_already) if(should_block == is_blocking_already)
return; return true;
//Different API calls needed for both windows and unix //Different API calls needed for both windows and unix
#ifdef WIN32 #ifdef WIN32
u_long non_blocking = should_block ? 0 : 1; u_long non_blocking = should_block ? 0 : 1;
ioctlsocket(socket_descriptor, FIONBIO, &non_blocking); int ret = ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
if(ret != 0)
return false;
#else #else
int flags = fcntl(socket_descriptor, F_GETFL, 0); int flags = fcntl(socket_descriptor, F_GETFL, 0);
fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK); if(flags < 0)
return false;
flags = fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
if(flags < 0)
return false;
#endif #endif
return true;
} }
static UNUSED_VAR void init_wsa() static UNUSED_VAR void init_wsa()

View File

@ -59,9 +59,10 @@ namespace fr
* *
* @param address The address of the socket to connect to * @param address The address of the socket to connect to
* @param port The port of the socket to connect to * @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation. * @return A Socket::Status indicating the status of the operation.
*/ */
Socket::Status connect(const std::string &address, const std::string &port) override; virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*! /*!
* Set the SSL context * Set the SSL context

View File

@ -73,9 +73,10 @@ namespace fr
* *
* @param address The address of the socket to connect to * @param address The address of the socket to connect to
* @param port The port of the socket to connect to * @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation. * @return A Socket::Status indicating the status of the operation.
*/ */
virtual Socket::Status connect(const std::string &address, const std::string &port)=0; virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0;
/*! /*!
* Gets the socket's printable remote address * Gets the socket's printable remote address

View File

@ -16,10 +16,12 @@ namespace fr
{ {
public: public:
SocketSelector() noexcept; SocketSelector() noexcept;
~SocketSelector();
/*! /*!
* Waits for a socket to become ready. * Waits for a socket to become ready.
* *
* @throws An std::exception if the socket encountered an error
* @param timeout The amount of time wait should block for before timing out. * @param timeout The amount of time wait should block for before timing out.
* @return True if a socket is ready. False if it timed out. * @return True if a socket is ready. False if it timed out.
*/ */
@ -33,13 +35,25 @@ namespace fr
* @param socket The socket to add. * @param socket The socket to add.
*/ */
template<typename T> template<typename T>
void add(const T &socket) inline void add(const T &socket)
{
add(socket.get_socket_descriptor());
}
/*!
* Adds a socket to the selector. Note that SocketSelector
* does not keep a copy of the object, just a handle, it's
* up to you to store your fr::Sockets.
*
* @param socket The socket descriptor to add.
*/
void add(int32_t socket_descriptor)
{ {
//Add it to the set //Add it to the set
FD_SET(socket.get_socket_descriptor(), &listen_set); FD_SET(socket_descriptor, &listen_set);
if(socket.get_socket_descriptor() > max_descriptor) if(socket_descriptor > max_descriptor)
max_descriptor = socket.get_socket_descriptor(); max_descriptor = socket_descriptor;
} }
/*! /*!

View File

@ -30,9 +30,10 @@ public:
* *
* @param address The address of the socket to connect to * @param address The address of the socket to connect to
* @param port The port of the socket to connect to * @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation. * @return A Socket::Status indicating the status of the operation.
*/ */
virtual Socket::Status connect(const std::string &address, const std::string &port); virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*! /*!
* Sets the socket file descriptor. * Sets the socket file descriptor.

View File

@ -73,7 +73,7 @@ namespace fr
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
} }
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0) if(read <= 0)
{ {
close_socket(); close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
@ -84,7 +84,7 @@ namespace fr
} }
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{ {
//Initialise mbedtls stuff //Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>(); ssl = std::make_unique<mbedtls_ssl_context>();
@ -92,14 +92,20 @@ namespace fr
mbedtls_ssl_init(ssl.get()); mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get()); mbedtls_net_init(ssl_socket_descriptor.get());
//Initialise the connection using mbedtlsl //Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
int error = 0; //Open the descriptor, and then steal it. This is a hack.
if((error = mbedtls_net_connect(ssl_socket_descriptor.get(), address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0)
{ {
return Socket::Status::ConnectionFailed; fr::TcpSocket socket;
auto ret = socket.connect(address, port, timeout);
if(ret != fr::Socket::Success)
return ret;
ssl_socket_descriptor->fd = socket.get_socket_descriptor();
remote_address = socket.get_remote_address();
socket.set_descriptor(-1);
} }
//Initialise SSL data structures //Initialise SSL data structures
int error = 0;
if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
{ {
return Socket::Status::Error; return Socket::Status::Error;
@ -142,7 +148,6 @@ namespace fr
} }
//Update state //Update state
remote_address = address + ":" + port;
reconfigure_socket(); reconfigure_socket();
return Socket::Status::Success; return Socket::Status::Success;

View File

@ -18,6 +18,12 @@ namespace fr
max_descriptor = 0; max_descriptor = 0;
} }
SocketSelector::~SocketSelector()
{
FD_ZERO(&listen_set);
FD_ZERO(&listen_read);
}
bool SocketSelector::wait(std::chrono::milliseconds timeout) bool SocketSelector::wait(std::chrono::milliseconds timeout)
{ {
//Windows will crash if we pass an empty set. Do a check. //Windows will crash if we pass an empty set. Do a check.
@ -35,8 +41,7 @@ namespace fr
wait_time.tv_usec = std::chrono::duration_cast<std::chrono::microseconds>(timeout).count(); wait_time.tv_usec = std::chrono::duration_cast<std::chrono::microseconds>(timeout).count();
listen_read = listen_set; listen_read = listen_set;
int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout == std::chrono::milliseconds(0) ? nullptr int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout.count() == 0 ? nullptr : &wait_time);
: &wait_time);
if(select_result == 0) //If it's timed out if(select_result == 0) //If it's timed out
return false; return false;
@ -45,4 +50,5 @@ namespace fr
return true; return true;
} }
} }

View File

@ -3,7 +3,9 @@
// //
#include <iostream> #include <iostream>
#include <frnetlib/SocketSelector.h>
#include "frnetlib/TcpSocket.h" #include "frnetlib/TcpSocket.h"
#define DEFAULT_SOCKET_TIMEOUT 20
namespace fr namespace fr
{ {
@ -83,46 +85,77 @@ namespace fr
socket_descriptor = descriptor; socket_descriptor = descriptor;
} }
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port) Socket::Status TcpSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{ {
//Setup required structures
int ret = 0;
addrinfo *info; addrinfo *info;
addrinfo hints{}; addrinfo hints{};
memset(&hints, 0, sizeof(addrinfo)); memset(&hints, 0, sizeof(addrinfo));
//Setup connection settings
hints.ai_family = ai_family; hints.ai_family = ai_family;
hints.ai_socktype = SOCK_STREAM; //TCP hints.ai_socktype = SOCK_STREAM; //TCP
hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us
//Query remote address information
if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0) if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0)
{ {
return Socket::Status::Error; return Socket::Status::Error;
} }
//Try to connect to results returned by getaddrinfo until we succeed/run out of things
addrinfo *c; addrinfo *c;
for(c = info; c != nullptr; c = c->ai_next) for(c = info; c != nullptr; c = c->ai_next)
{ {
//Get the socket for this entry
socket_descriptor = ::socket(c->ai_family, c->ai_socktype, c->ai_protocol); socket_descriptor = ::socket(c->ai_family, c->ai_socktype, c->ai_protocol);
if(socket_descriptor == INVALID_SOCKET) if(socket_descriptor == INVALID_SOCKET)
{
continue; continue;
}
if(::connect(socket_descriptor, c->ai_addr, c->ai_addrlen) == SOCKET_ERROR) //Put it into non-blocking mode, to allow for a custom connect timeout
{ if(!set_unix_socket_blocking(socket_descriptor, true, false))
continue;
//Try and connect
ret = ::connect(socket_descriptor, c->ai_addr, c->ai_addrlen);
if(ret < 0 && errno != EINPROGRESS)
continue;
else if(ret == 0) //If it connected immediately then break out of the connect loop
break;
//Wait for the socket to do something/expire
timeval tv = {};
tv.tv_sec = timeout.count() == -1 ? DEFAULT_SOCKET_TIMEOUT : timeout.count();
tv.tv_usec = 0;
fd_set set = {};
FD_ZERO(&set);
FD_SET(socket_descriptor, &set);
ret = select(socket_descriptor + 1, nullptr, &set, nullptr, &tv);
if(ret <= 0)
continue;
//Verify that we're connected
socklen_t len = sizeof(ret);
if(getsockopt(socket_descriptor, SOL_SOCKET, SO_ERROR, &ret, &len) == -1)
continue;
if(ret != 0)
continue; continue;
}
break; break;
} }
//We're done with this now, cleanup //We're done with this now, cleanup
freeaddrinfo(info); freeaddrinfo(info);
if(c == nullptr) if(c == nullptr)
return Socket::Status::Error; return Socket::Status::Error;
//Turn back to blocking mode
if(!set_unix_socket_blocking(socket_descriptor, false, true))
return Socket::Status::Error;
//Update state now we've got a valid socket descriptor //Update state now we've got a valid socket descriptor
remote_address = address + ":" + port; remote_address = address + ":" + port;
reconfigure_socket(); reconfigure_socket();

View File

@ -28,7 +28,7 @@ TEST(TcpListenerTest, listener_accept)
{ {
fr::TcpSocket socket; fr::TcpSocket socket;
socket.set_inet_version(fr::Socket::IP::v4); socket.set_inet_version(fr::Socket::IP::v4);
auto ret = socket.connect("127.0.0.1", "9095"); auto ret = socket.connect("127.0.0.1", "9095", std::chrono::seconds(5));
ASSERT_EQ(ret, fr::Socket::Success); ASSERT_EQ(ret, fr::Socket::Success);
}; };