Added support for connect timeouts.

Both fr::TcpSocket and fr::SSLSocket can have timeouts specified when connecting. This works by putting the socket into non-blocking mode, making a connect, and then selecting on the socket for the requested timeout. If the select times out then we've failed to connect, if it didn't time out then we've connected.
This commit is contained in:
Fred Nicolson 2018-01-10 17:08:16 +00:00
parent 2215d068af
commit 30354f15bc
10 changed files with 97 additions and 29 deletions

View File

@ -6,7 +6,7 @@ set(FRNETLIB_LINK_LIBRARIES "")
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules)
#User options
option(USE_SSL "Use SSL" OFF)
option(USE_SSL "Use SSL" ON)
set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.")
set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes")
set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes")

View File

@ -82,20 +82,27 @@ inline double ntohd(double val)
return val;
}
inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
inline bool set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
{
//Don't update it if we're already in that mode
if(should_block == is_blocking_already)
return;
return true;
//Different API calls needed for both windows and unix
#ifdef WIN32
u_long non_blocking = should_block ? 0 : 1;
ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
int ret = ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
if(ret != 0)
return false;
#else
int flags = fcntl(socket_descriptor, F_GETFL, 0);
fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
if(flags < 0)
return false;
flags = fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
if(flags < 0)
return false;
#endif
return true;
}
static UNUSED_VAR void init_wsa()

View File

@ -59,9 +59,10 @@ namespace fr
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
*/
Socket::Status connect(const std::string &address, const std::string &port) override;
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Set the SSL context

View File

@ -73,9 +73,10 @@ namespace fr
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
*/
virtual Socket::Status connect(const std::string &address, const std::string &port)=0;
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0;
/*!
* Gets the socket's printable remote address

View File

@ -16,10 +16,12 @@ namespace fr
{
public:
SocketSelector() noexcept;
~SocketSelector();
/*!
* Waits for a socket to become ready.
*
* @throws An std::exception if the socket encountered an error
* @param timeout The amount of time wait should block for before timing out.
* @return True if a socket is ready. False if it timed out.
*/
@ -33,13 +35,25 @@ namespace fr
* @param socket The socket to add.
*/
template<typename T>
void add(const T &socket)
inline void add(const T &socket)
{
add(socket.get_socket_descriptor());
}
/*!
* Adds a socket to the selector. Note that SocketSelector
* does not keep a copy of the object, just a handle, it's
* up to you to store your fr::Sockets.
*
* @param socket The socket descriptor to add.
*/
void add(int32_t socket_descriptor)
{
//Add it to the set
FD_SET(socket.get_socket_descriptor(), &listen_set);
FD_SET(socket_descriptor, &listen_set);
if(socket.get_socket_descriptor() > max_descriptor)
max_descriptor = socket.get_socket_descriptor();
if(socket_descriptor > max_descriptor)
max_descriptor = socket_descriptor;
}
/*!

View File

@ -30,9 +30,10 @@ public:
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
* @return A Socket::Status indicating the status of the operation.
*/
virtual Socket::Status connect(const std::string &address, const std::string &port);
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
/*!
* Sets the socket file descriptor.

View File

@ -73,7 +73,7 @@ namespace fr
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
}
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0)
if(read <= 0)
{
close_socket();
return Socket::Status::Disconnected;
@ -84,7 +84,7 @@ namespace fr
}
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{
//Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>();
@ -92,14 +92,20 @@ namespace fr
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get());
//Initialise the connection using mbedtlsl
int error = 0;
if((error = mbedtls_net_connect(ssl_socket_descriptor.get(), address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0)
//Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
//Open the descriptor, and then steal it. This is a hack.
{
return Socket::Status::ConnectionFailed;
fr::TcpSocket socket;
auto ret = socket.connect(address, port, timeout);
if(ret != fr::Socket::Success)
return ret;
ssl_socket_descriptor->fd = socket.get_socket_descriptor();
remote_address = socket.get_remote_address();
socket.set_descriptor(-1);
}
//Initialise SSL data structures
int error = 0;
if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
{
return Socket::Status::Error;
@ -142,7 +148,6 @@ namespace fr
}
//Update state
remote_address = address + ":" + port;
reconfigure_socket();
return Socket::Status::Success;

View File

@ -18,6 +18,12 @@ namespace fr
max_descriptor = 0;
}
SocketSelector::~SocketSelector()
{
FD_ZERO(&listen_set);
FD_ZERO(&listen_read);
}
bool SocketSelector::wait(std::chrono::milliseconds timeout)
{
//Windows will crash if we pass an empty set. Do a check.
@ -35,8 +41,7 @@ namespace fr
wait_time.tv_usec = std::chrono::duration_cast<std::chrono::microseconds>(timeout).count();
listen_read = listen_set;
int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout == std::chrono::milliseconds(0) ? nullptr
: &wait_time);
int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout.count() == 0 ? nullptr : &wait_time);
if(select_result == 0) //If it's timed out
return false;
@ -45,4 +50,5 @@ namespace fr
return true;
}
}

View File

@ -3,7 +3,9 @@
//
#include <iostream>
#include <frnetlib/SocketSelector.h>
#include "frnetlib/TcpSocket.h"
#define DEFAULT_SOCKET_TIMEOUT 20
namespace fr
{
@ -83,46 +85,77 @@ namespace fr
socket_descriptor = descriptor;
}
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port)
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
{
//Setup required structures
int ret = 0;
addrinfo *info;
addrinfo hints{};
memset(&hints, 0, sizeof(addrinfo));
//Setup connection settings
hints.ai_family = ai_family;
hints.ai_socktype = SOCK_STREAM; //TCP
hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us
//Query remote address information
if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0)
{
return Socket::Status::Error;
}
//Try to connect to results returned by getaddrinfo until we succeed/run out of things
addrinfo *c;
for(c = info; c != nullptr; c = c->ai_next)
{
//Get the socket for this entry
socket_descriptor = ::socket(c->ai_family, c->ai_socktype, c->ai_protocol);
if(socket_descriptor == INVALID_SOCKET)
{
continue;
}
if(::connect(socket_descriptor, c->ai_addr, c->ai_addrlen) == SOCKET_ERROR)
{
//Put it into non-blocking mode, to allow for a custom connect timeout
if(!set_unix_socket_blocking(socket_descriptor, true, false))
continue;
//Try and connect
ret = ::connect(socket_descriptor, c->ai_addr, c->ai_addrlen);
if(ret < 0 && errno != EINPROGRESS)
continue;
else if(ret == 0) //If it connected immediately then break out of the connect loop
break;
//Wait for the socket to do something/expire
timeval tv = {};
tv.tv_sec = timeout.count() == -1 ? DEFAULT_SOCKET_TIMEOUT : timeout.count();
tv.tv_usec = 0;
fd_set set = {};
FD_ZERO(&set);
FD_SET(socket_descriptor, &set);
ret = select(socket_descriptor + 1, nullptr, &set, nullptr, &tv);
if(ret <= 0)
continue;
//Verify that we're connected
socklen_t len = sizeof(ret);
if(getsockopt(socket_descriptor, SOL_SOCKET, SO_ERROR, &ret, &len) == -1)
continue;
if(ret != 0)
continue;
}
break;
}
//We're done with this now, cleanup
freeaddrinfo(info);
if(c == nullptr)
return Socket::Status::Error;
//Turn back to blocking mode
if(!set_unix_socket_blocking(socket_descriptor, false, true))
return Socket::Status::Error;
//Update state now we've got a valid socket descriptor
remote_address = address + ":" + port;
reconfigure_socket();

View File

@ -28,7 +28,7 @@ TEST(TcpListenerTest, listener_accept)
{
fr::TcpSocket socket;
socket.set_inet_version(fr::Socket::IP::v4);
auto ret = socket.connect("127.0.0.1", "9095");
auto ret = socket.connect("127.0.0.1", "9095", std::chrono::seconds(5));
ASSERT_EQ(ret, fr::Socket::Success);
};