diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 322481e..53c2da0 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -98,6 +98,16 @@ namespace fr abort(); } + /*! + * Checks to see if we're connected to a socket or not + * + * @return True if it's connected. False otherwise. + */ + inline virtual bool connected() const override final + { + return ssl_socket_descriptor->fd > -1; + } + private: std::shared_ptr ssl_context; diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index 92cf6c4..5b5e1d9 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -131,10 +131,7 @@ namespace fr * * @return True if it's connected. False otherwise. */ - inline bool connected() const - { - return is_connected; - } + virtual bool connected() const =0; /*! * Gets the socket descriptor. @@ -162,6 +159,7 @@ namespace fr */ void set_inet_version(IP version); protected: + /*! * Applies requested socket options to the socket. * Should be called when a new socket is created. @@ -170,7 +168,6 @@ namespace fr std::string remote_address; bool is_blocking; - bool is_connected; std::mutex outbound_mutex; std::mutex inbound_mutex; int ai_family; diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index 47ece61..b997587 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -93,6 +93,16 @@ public: */ int32_t get_socket_descriptor() const override; + /*! + * Checks to see if we're connected to a socket or not + * + * @return True if it's connected. False otherwise. + */ + inline virtual bool connected() const override final + { + return socket_descriptor > -1; + } + protected: int32_t socket_descriptor; }; diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index b24c993..283b3cb 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -4,6 +4,8 @@ #include "frnetlib/SSLSocket.h" #include +#include + #ifdef SSL_ENABLED namespace fr @@ -27,13 +29,12 @@ namespace fr void SSLSocket::close_socket() { - if(is_connected) + if(ssl_socket_descriptor->fd > -1) { if(ssl) mbedtls_ssl_close_notify(ssl.get()); if(ssl_socket_descriptor) mbedtls_net_free(ssl_socket_descriptor.get()); - is_connected = false; } } @@ -44,7 +45,7 @@ namespace fr { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { - is_connected = false; + close_socket(); return Socket::Status::Disconnected; } } @@ -64,7 +65,7 @@ namespace fr if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { - is_connected = false; + close_socket(); return Socket::Status::Disconnected; } else if(read <= 0) @@ -136,7 +137,6 @@ namespace fr } //Update state - is_connected = true; remote_address = address + ":" + port; reconfigure_socket(); @@ -150,7 +150,6 @@ namespace fr void SSLSocket::set_net_context(std::unique_ptr context) { - is_connected = true; ssl_socket_descriptor = std::move(context); reconfigure_socket(); } diff --git a/src/Socket.cpp b/src/Socket.cpp index 4f5daf8..380cda9 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -16,7 +16,6 @@ namespace fr Socket::Socket() noexcept : is_blocking(true), - is_connected(false), ai_family(AF_UNSPEC) { if(instance_count == 0) @@ -38,7 +37,7 @@ namespace fr Socket::Status Socket::send(Packet &packet) { - if(!is_connected) + if(!connected()) return Socket::Disconnected; std::string &data = packet.get_buffer(); @@ -47,7 +46,7 @@ namespace fr Socket::Status Socket::send(Packet &&packet) { - if(!is_connected) + if(!connected()) return Socket::Disconnected; std::string &data = packet.get_buffer(); @@ -56,7 +55,7 @@ namespace fr Socket::Status Socket::receive(Packet &packet) { - if(!is_connected) + if(!connected()) return Socket::Disconnected; Socket::Status status; @@ -81,7 +80,7 @@ namespace fr Socket::Status Socket::receive_all(void *dest, size_t buffer_size) { - if(!is_connected) + if(!connected()) return Socket::Disconnected; ssize_t bytes_remaining = buffer_size; diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 3ef975a..49ea2c0 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -9,6 +9,7 @@ namespace fr { TcpSocket::TcpSocket() noexcept + : socket_descriptor(-1) { } @@ -40,10 +41,10 @@ namespace fr void TcpSocket::close_socket() { - if(is_connected) + if(socket_descriptor > -1) { ::closesocket(socket_descriptor); - is_connected = false; + socket_descriptor = -1; } } @@ -80,7 +81,6 @@ namespace fr { reconfigure_socket(); socket_descriptor = descriptor; - is_connected = true; } Socket::Status TcpSocket::connect(const std::string &address, const std::string &port) @@ -124,7 +124,6 @@ namespace fr freeaddrinfo(info); //Update state now we've got a valid socket descriptor - is_connected = true; remote_address = address + ":" + port; reconfigure_socket();