Added support for connect timeouts.
Both fr::TcpSocket and fr::SSLSocket can have timeouts specified when connecting. This works by putting the socket into non-blocking mode, making a connect, and then selecting on the socket for the requested timeout. If the select times out then we've failed to connect, if it didn't time out then we've connected.
This commit is contained in:
parent
2215d068af
commit
30354f15bc
@ -6,7 +6,7 @@ set(FRNETLIB_LINK_LIBRARIES "")
|
||||
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules)
|
||||
|
||||
#User options
|
||||
option(USE_SSL "Use SSL" OFF)
|
||||
option(USE_SSL "Use SSL" ON)
|
||||
set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.")
|
||||
set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes")
|
||||
set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes")
|
||||
|
||||
@ -82,20 +82,27 @@ inline double ntohd(double val)
|
||||
return val;
|
||||
}
|
||||
|
||||
inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
|
||||
inline bool set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block)
|
||||
{
|
||||
//Don't update it if we're already in that mode
|
||||
if(should_block == is_blocking_already)
|
||||
return;
|
||||
return true;
|
||||
|
||||
//Different API calls needed for both windows and unix
|
||||
#ifdef WIN32
|
||||
u_long non_blocking = should_block ? 0 : 1;
|
||||
ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
|
||||
int ret = ioctlsocket(socket_descriptor, FIONBIO, &non_blocking);
|
||||
if(ret != 0)
|
||||
return false;
|
||||
#else
|
||||
int flags = fcntl(socket_descriptor, F_GETFL, 0);
|
||||
fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
|
||||
if(flags < 0)
|
||||
return false;
|
||||
flags = fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK);
|
||||
if(flags < 0)
|
||||
return false;
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
static UNUSED_VAR void init_wsa()
|
||||
|
||||
@ -59,9 +59,10 @@ namespace fr
|
||||
*
|
||||
* @param address The address of the socket to connect to
|
||||
* @param port The port of the socket to connect to
|
||||
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
|
||||
* @return A Socket::Status indicating the status of the operation.
|
||||
*/
|
||||
Socket::Status connect(const std::string &address, const std::string &port) override;
|
||||
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
|
||||
|
||||
/*!
|
||||
* Set the SSL context
|
||||
|
||||
@ -73,9 +73,10 @@ namespace fr
|
||||
*
|
||||
* @param address The address of the socket to connect to
|
||||
* @param port The port of the socket to connect to
|
||||
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
|
||||
* @return A Socket::Status indicating the status of the operation.
|
||||
*/
|
||||
virtual Socket::Status connect(const std::string &address, const std::string &port)=0;
|
||||
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0;
|
||||
|
||||
/*!
|
||||
* Gets the socket's printable remote address
|
||||
|
||||
@ -16,10 +16,12 @@ namespace fr
|
||||
{
|
||||
public:
|
||||
SocketSelector() noexcept;
|
||||
~SocketSelector();
|
||||
|
||||
/*!
|
||||
* Waits for a socket to become ready.
|
||||
*
|
||||
* @throws An std::exception if the socket encountered an error
|
||||
* @param timeout The amount of time wait should block for before timing out.
|
||||
* @return True if a socket is ready. False if it timed out.
|
||||
*/
|
||||
@ -33,13 +35,25 @@ namespace fr
|
||||
* @param socket The socket to add.
|
||||
*/
|
||||
template<typename T>
|
||||
void add(const T &socket)
|
||||
inline void add(const T &socket)
|
||||
{
|
||||
add(socket.get_socket_descriptor());
|
||||
}
|
||||
|
||||
/*!
|
||||
* Adds a socket to the selector. Note that SocketSelector
|
||||
* does not keep a copy of the object, just a handle, it's
|
||||
* up to you to store your fr::Sockets.
|
||||
*
|
||||
* @param socket The socket descriptor to add.
|
||||
*/
|
||||
void add(int32_t socket_descriptor)
|
||||
{
|
||||
//Add it to the set
|
||||
FD_SET(socket.get_socket_descriptor(), &listen_set);
|
||||
FD_SET(socket_descriptor, &listen_set);
|
||||
|
||||
if(socket.get_socket_descriptor() > max_descriptor)
|
||||
max_descriptor = socket.get_socket_descriptor();
|
||||
if(socket_descriptor > max_descriptor)
|
||||
max_descriptor = socket_descriptor;
|
||||
}
|
||||
|
||||
/*!
|
||||
|
||||
@ -30,9 +30,10 @@ public:
|
||||
*
|
||||
* @param address The address of the socket to connect to
|
||||
* @param port The port of the socket to connect to
|
||||
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default.
|
||||
* @return A Socket::Status indicating the status of the operation.
|
||||
*/
|
||||
virtual Socket::Status connect(const std::string &address, const std::string &port);
|
||||
virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override;
|
||||
|
||||
/*!
|
||||
* Sets the socket file descriptor.
|
||||
|
||||
@ -73,7 +73,7 @@ namespace fr
|
||||
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
|
||||
}
|
||||
|
||||
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0)
|
||||
if(read <= 0)
|
||||
{
|
||||
close_socket();
|
||||
return Socket::Status::Disconnected;
|
||||
@ -84,7 +84,7 @@ namespace fr
|
||||
|
||||
}
|
||||
|
||||
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port)
|
||||
Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
|
||||
{
|
||||
//Initialise mbedtls stuff
|
||||
ssl = std::make_unique<mbedtls_ssl_context>();
|
||||
@ -92,14 +92,20 @@ namespace fr
|
||||
mbedtls_ssl_init(ssl.get());
|
||||
mbedtls_net_init(ssl_socket_descriptor.get());
|
||||
|
||||
//Initialise the connection using mbedtlsl
|
||||
int error = 0;
|
||||
if((error = mbedtls_net_connect(ssl_socket_descriptor.get(), address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0)
|
||||
//Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to
|
||||
//Open the descriptor, and then steal it. This is a hack.
|
||||
{
|
||||
return Socket::Status::ConnectionFailed;
|
||||
fr::TcpSocket socket;
|
||||
auto ret = socket.connect(address, port, timeout);
|
||||
if(ret != fr::Socket::Success)
|
||||
return ret;
|
||||
ssl_socket_descriptor->fd = socket.get_socket_descriptor();
|
||||
remote_address = socket.get_remote_address();
|
||||
socket.set_descriptor(-1);
|
||||
}
|
||||
|
||||
//Initialise SSL data structures
|
||||
int error = 0;
|
||||
if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0)
|
||||
{
|
||||
return Socket::Status::Error;
|
||||
@ -142,7 +148,6 @@ namespace fr
|
||||
}
|
||||
|
||||
//Update state
|
||||
remote_address = address + ":" + port;
|
||||
reconfigure_socket();
|
||||
|
||||
return Socket::Status::Success;
|
||||
|
||||
@ -18,6 +18,12 @@ namespace fr
|
||||
max_descriptor = 0;
|
||||
}
|
||||
|
||||
SocketSelector::~SocketSelector()
|
||||
{
|
||||
FD_ZERO(&listen_set);
|
||||
FD_ZERO(&listen_read);
|
||||
}
|
||||
|
||||
bool SocketSelector::wait(std::chrono::milliseconds timeout)
|
||||
{
|
||||
//Windows will crash if we pass an empty set. Do a check.
|
||||
@ -35,8 +41,7 @@ namespace fr
|
||||
wait_time.tv_usec = std::chrono::duration_cast<std::chrono::microseconds>(timeout).count();
|
||||
|
||||
listen_read = listen_set;
|
||||
int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout == std::chrono::milliseconds(0) ? nullptr
|
||||
: &wait_time);
|
||||
int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout.count() == 0 ? nullptr : &wait_time);
|
||||
|
||||
if(select_result == 0) //If it's timed out
|
||||
return false;
|
||||
@ -45,4 +50,5 @@ namespace fr
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include <iostream>
|
||||
#include <frnetlib/SocketSelector.h>
|
||||
#include "frnetlib/TcpSocket.h"
|
||||
#define DEFAULT_SOCKET_TIMEOUT 20
|
||||
|
||||
namespace fr
|
||||
{
|
||||
@ -83,46 +85,77 @@ namespace fr
|
||||
socket_descriptor = descriptor;
|
||||
}
|
||||
|
||||
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port)
|
||||
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)
|
||||
{
|
||||
//Setup required structures
|
||||
int ret = 0;
|
||||
addrinfo *info;
|
||||
addrinfo hints{};
|
||||
|
||||
memset(&hints, 0, sizeof(addrinfo));
|
||||
|
||||
//Setup connection settings
|
||||
hints.ai_family = ai_family;
|
||||
hints.ai_socktype = SOCK_STREAM; //TCP
|
||||
hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us
|
||||
|
||||
//Query remote address information
|
||||
if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0)
|
||||
{
|
||||
return Socket::Status::Error;
|
||||
}
|
||||
|
||||
//Try to connect to results returned by getaddrinfo until we succeed/run out of things
|
||||
addrinfo *c;
|
||||
for(c = info; c != nullptr; c = c->ai_next)
|
||||
{
|
||||
//Get the socket for this entry
|
||||
socket_descriptor = ::socket(c->ai_family, c->ai_socktype, c->ai_protocol);
|
||||
if(socket_descriptor == INVALID_SOCKET)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
if(::connect(socket_descriptor, c->ai_addr, c->ai_addrlen) == SOCKET_ERROR)
|
||||
{
|
||||
//Put it into non-blocking mode, to allow for a custom connect timeout
|
||||
if(!set_unix_socket_blocking(socket_descriptor, true, false))
|
||||
continue;
|
||||
|
||||
//Try and connect
|
||||
ret = ::connect(socket_descriptor, c->ai_addr, c->ai_addrlen);
|
||||
if(ret < 0 && errno != EINPROGRESS)
|
||||
continue;
|
||||
else if(ret == 0) //If it connected immediately then break out of the connect loop
|
||||
break;
|
||||
|
||||
//Wait for the socket to do something/expire
|
||||
timeval tv = {};
|
||||
tv.tv_sec = timeout.count() == -1 ? DEFAULT_SOCKET_TIMEOUT : timeout.count();
|
||||
tv.tv_usec = 0;
|
||||
fd_set set = {};
|
||||
FD_ZERO(&set);
|
||||
FD_SET(socket_descriptor, &set);
|
||||
ret = select(socket_descriptor + 1, nullptr, &set, nullptr, &tv);
|
||||
if(ret <= 0)
|
||||
continue;
|
||||
|
||||
//Verify that we're connected
|
||||
socklen_t len = sizeof(ret);
|
||||
if(getsockopt(socket_descriptor, SOL_SOCKET, SO_ERROR, &ret, &len) == -1)
|
||||
continue;
|
||||
if(ret != 0)
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
//We're done with this now, cleanup
|
||||
freeaddrinfo(info);
|
||||
|
||||
if(c == nullptr)
|
||||
return Socket::Status::Error;
|
||||
|
||||
//Turn back to blocking mode
|
||||
if(!set_unix_socket_blocking(socket_descriptor, false, true))
|
||||
return Socket::Status::Error;
|
||||
|
||||
//Update state now we've got a valid socket descriptor
|
||||
remote_address = address + ":" + port;
|
||||
reconfigure_socket();
|
||||
|
||||
@ -28,7 +28,7 @@ TEST(TcpListenerTest, listener_accept)
|
||||
{
|
||||
fr::TcpSocket socket;
|
||||
socket.set_inet_version(fr::Socket::IP::v4);
|
||||
auto ret = socket.connect("127.0.0.1", "9095");
|
||||
auto ret = socket.connect("127.0.0.1", "9095", std::chrono::seconds(5));
|
||||
ASSERT_EQ(ret, fr::Socket::Success);
|
||||
};
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user