diff --git a/CMakeLists.txt b/CMakeLists.txt index 490dabe..fd8efe6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ set(FRNETLIB_LINK_LIBRARIES "") set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules) #User options -option(USE_SSL "Use SSL" OFF) +option(USE_SSL "Use SSL" ON) set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.") set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes") set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes") diff --git a/include/frnetlib/NetworkEncoding.h b/include/frnetlib/NetworkEncoding.h index 0097895..1bceb28 100644 --- a/include/frnetlib/NetworkEncoding.h +++ b/include/frnetlib/NetworkEncoding.h @@ -82,20 +82,27 @@ inline double ntohd(double val) return val; } -inline void set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block) +inline bool set_unix_socket_blocking(int32_t socket_descriptor, bool is_blocking_already, bool should_block) { //Don't update it if we're already in that mode if(should_block == is_blocking_already) - return; + return true; //Different API calls needed for both windows and unix #ifdef WIN32 u_long non_blocking = should_block ? 0 : 1; - ioctlsocket(socket_descriptor, FIONBIO, &non_blocking); + int ret = ioctlsocket(socket_descriptor, FIONBIO, &non_blocking); + if(ret != 0) + return false; #else int flags = fcntl(socket_descriptor, F_GETFL, 0); - fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK); + if(flags < 0) + return false; + flags = fcntl(socket_descriptor, F_SETFL, is_blocking_already ? flags ^ O_NONBLOCK : flags ^= O_NONBLOCK); + if(flags < 0) + return false; #endif + return true; } static UNUSED_VAR void init_wsa() diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 0d909f7..7767e46 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -59,9 +59,10 @@ namespace fr * * @param address The address of the socket to connect to * @param port The port of the socket to connect to + * @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default. * @return A Socket::Status indicating the status of the operation. */ - Socket::Status connect(const std::string &address, const std::string &port) override; + virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override; /*! * Set the SSL context diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index b38db19..f5e12f4 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -73,9 +73,10 @@ namespace fr * * @param address The address of the socket to connect to * @param port The port of the socket to connect to + * @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default. * @return A Socket::Status indicating the status of the operation. */ - virtual Socket::Status connect(const std::string &address, const std::string &port)=0; + virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout)=0; /*! * Gets the socket's printable remote address diff --git a/include/frnetlib/SocketSelector.h b/include/frnetlib/SocketSelector.h index c8c9e75..32c4a47 100644 --- a/include/frnetlib/SocketSelector.h +++ b/include/frnetlib/SocketSelector.h @@ -16,10 +16,12 @@ namespace fr { public: SocketSelector() noexcept; + ~SocketSelector(); /*! * Waits for a socket to become ready. * + * @throws An std::exception if the socket encountered an error * @param timeout The amount of time wait should block for before timing out. * @return True if a socket is ready. False if it timed out. */ @@ -33,13 +35,25 @@ namespace fr * @param socket The socket to add. */ template - void add(const T &socket) + inline void add(const T &socket) + { + add(socket.get_socket_descriptor()); + } + + /*! + * Adds a socket to the selector. Note that SocketSelector + * does not keep a copy of the object, just a handle, it's + * up to you to store your fr::Sockets. + * + * @param socket The socket descriptor to add. + */ + void add(int32_t socket_descriptor) { //Add it to the set - FD_SET(socket.get_socket_descriptor(), &listen_set); + FD_SET(socket_descriptor, &listen_set); - if(socket.get_socket_descriptor() > max_descriptor) - max_descriptor = socket.get_socket_descriptor(); + if(socket_descriptor > max_descriptor) + max_descriptor = socket_descriptor; } /*! diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index ff3b42b..1ac038f 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -30,9 +30,10 @@ public: * * @param address The address of the socket to connect to * @param port The port of the socket to connect to + * @param timeout The number of seconds to wait before timing the connection attempt out. Pass -1 for default. * @return A Socket::Status indicating the status of the operation. */ - virtual Socket::Status connect(const std::string &address, const std::string &port); + virtual Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override; /*! * Sets the socket file descriptor. diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 1ea6e22..14a306f 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -73,7 +73,7 @@ namespace fr read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); } - if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || read <= 0) + if(read <= 0) { close_socket(); return Socket::Status::Disconnected; @@ -84,7 +84,7 @@ namespace fr } - Socket::Status SSLSocket::connect(const std::string &address, const std::string &port) + Socket::Status SSLSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) { //Initialise mbedtls stuff ssl = std::make_unique(); @@ -92,14 +92,20 @@ namespace fr mbedtls_ssl_init(ssl.get()); mbedtls_net_init(ssl_socket_descriptor.get()); - //Initialise the connection using mbedtlsl - int error = 0; - if((error = mbedtls_net_connect(ssl_socket_descriptor.get(), address.c_str(), port.c_str(), MBEDTLS_NET_PROTO_TCP)) != 0) + //Do to mbedtls not supporting connect timeouts, we have to use an fr::TcpSocket to + //Open the descriptor, and then steal it. This is a hack. { - return Socket::Status::ConnectionFailed; + fr::TcpSocket socket; + auto ret = socket.connect(address, port, timeout); + if(ret != fr::Socket::Success) + return ret; + ssl_socket_descriptor->fd = socket.get_socket_descriptor(); + remote_address = socket.get_remote_address(); + socket.set_descriptor(-1); } //Initialise SSL data structures + int error = 0; if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { return Socket::Status::Error; @@ -142,7 +148,6 @@ namespace fr } //Update state - remote_address = address + ":" + port; reconfigure_socket(); return Socket::Status::Success; diff --git a/src/SocketSelector.cpp b/src/SocketSelector.cpp index 1f4be08..0e7faa1 100644 --- a/src/SocketSelector.cpp +++ b/src/SocketSelector.cpp @@ -18,6 +18,12 @@ namespace fr max_descriptor = 0; } + SocketSelector::~SocketSelector() + { + FD_ZERO(&listen_set); + FD_ZERO(&listen_read); + } + bool SocketSelector::wait(std::chrono::milliseconds timeout) { //Windows will crash if we pass an empty set. Do a check. @@ -35,8 +41,7 @@ namespace fr wait_time.tv_usec = std::chrono::duration_cast(timeout).count(); listen_read = listen_set; - int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout == std::chrono::milliseconds(0) ? nullptr - : &wait_time); + int select_result = select(max_descriptor + 1, &listen_read, nullptr, nullptr, timeout.count() == 0 ? nullptr : &wait_time); if(select_result == 0) //If it's timed out return false; @@ -45,4 +50,5 @@ namespace fr return true; } + } \ No newline at end of file diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 9120610..e251a8e 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -3,7 +3,9 @@ // #include +#include #include "frnetlib/TcpSocket.h" +#define DEFAULT_SOCKET_TIMEOUT 20 namespace fr { @@ -83,46 +85,77 @@ namespace fr socket_descriptor = descriptor; } - Socket::Status TcpSocket::connect(const std::string &address, const std::string &port) + Socket::Status TcpSocket::connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) { + //Setup required structures + int ret = 0; addrinfo *info; addrinfo hints{}; memset(&hints, 0, sizeof(addrinfo)); + //Setup connection settings hints.ai_family = ai_family; hints.ai_socktype = SOCK_STREAM; //TCP hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us + //Query remote address information if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0) { return Socket::Status::Error; } + //Try to connect to results returned by getaddrinfo until we succeed/run out of things addrinfo *c; for(c = info; c != nullptr; c = c->ai_next) { + //Get the socket for this entry socket_descriptor = ::socket(c->ai_family, c->ai_socktype, c->ai_protocol); if(socket_descriptor == INVALID_SOCKET) - { continue; - } - if(::connect(socket_descriptor, c->ai_addr, c->ai_addrlen) == SOCKET_ERROR) - { + //Put it into non-blocking mode, to allow for a custom connect timeout + if(!set_unix_socket_blocking(socket_descriptor, true, false)) + continue; + + //Try and connect + ret = ::connect(socket_descriptor, c->ai_addr, c->ai_addrlen); + if(ret < 0 && errno != EINPROGRESS) + continue; + else if(ret == 0) //If it connected immediately then break out of the connect loop + break; + + //Wait for the socket to do something/expire + timeval tv = {}; + tv.tv_sec = timeout.count() == -1 ? DEFAULT_SOCKET_TIMEOUT : timeout.count(); + tv.tv_usec = 0; + fd_set set = {}; + FD_ZERO(&set); + FD_SET(socket_descriptor, &set); + ret = select(socket_descriptor + 1, nullptr, &set, nullptr, &tv); + if(ret <= 0) + continue; + + //Verify that we're connected + socklen_t len = sizeof(ret); + if(getsockopt(socket_descriptor, SOL_SOCKET, SO_ERROR, &ret, &len) == -1) + continue; + if(ret != 0) continue; - } break; } //We're done with this now, cleanup freeaddrinfo(info); - if(c == nullptr) return Socket::Status::Error; + //Turn back to blocking mode + if(!set_unix_socket_blocking(socket_descriptor, false, true)) + return Socket::Status::Error; + //Update state now we've got a valid socket descriptor remote_address = address + ":" + port; reconfigure_socket(); diff --git a/tests/TcpListenerTest.cpp b/tests/TcpListenerTest.cpp index b16153b..64c1b07 100644 --- a/tests/TcpListenerTest.cpp +++ b/tests/TcpListenerTest.cpp @@ -28,7 +28,7 @@ TEST(TcpListenerTest, listener_accept) { fr::TcpSocket socket; socket.set_inet_version(fr::Socket::IP::v4); - auto ret = socket.connect("127.0.0.1", "9095"); + auto ret = socket.connect("127.0.0.1", "9095", std::chrono::seconds(5)); ASSERT_EQ(ret, fr::Socket::Success); };