Removed is_connected flag in fr::Socket

The socket descriptor is now checked, to see if it's greater than 0, instead. This means that the flag doesn't have to be manually updated.
This commit is contained in:
Fred Nicolson 2017-05-25 16:21:39 +01:00
parent b2bb3ce60b
commit ae61464aee
6 changed files with 34 additions and 20 deletions

View File

@ -98,6 +98,16 @@ namespace fr
abort(); abort();
} }
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline virtual bool connected() const override final
{
return ssl_socket_descriptor->fd > -1;
}
private: private:
std::shared_ptr<SSLContext> ssl_context; std::shared_ptr<SSLContext> ssl_context;

View File

@ -131,10 +131,7 @@ namespace fr
* *
* @return True if it's connected. False otherwise. * @return True if it's connected. False otherwise.
*/ */
inline bool connected() const virtual bool connected() const =0;
{
return is_connected;
}
/*! /*!
* Gets the socket descriptor. * Gets the socket descriptor.
@ -162,6 +159,7 @@ namespace fr
*/ */
void set_inet_version(IP version); void set_inet_version(IP version);
protected: protected:
/*! /*!
* Applies requested socket options to the socket. * Applies requested socket options to the socket.
* Should be called when a new socket is created. * Should be called when a new socket is created.
@ -170,7 +168,6 @@ namespace fr
std::string remote_address; std::string remote_address;
bool is_blocking; bool is_blocking;
bool is_connected;
std::mutex outbound_mutex; std::mutex outbound_mutex;
std::mutex inbound_mutex; std::mutex inbound_mutex;
int ai_family; int ai_family;

View File

@ -93,6 +93,16 @@ public:
*/ */
int32_t get_socket_descriptor() const override; int32_t get_socket_descriptor() const override;
/*!
* Checks to see if we're connected to a socket or not
*
* @return True if it's connected. False otherwise.
*/
inline virtual bool connected() const override final
{
return socket_descriptor > -1;
}
protected: protected:
int32_t socket_descriptor; int32_t socket_descriptor;
}; };

View File

@ -4,6 +4,8 @@
#include "frnetlib/SSLSocket.h" #include "frnetlib/SSLSocket.h"
#include <memory> #include <memory>
#include <mbedtls/net_sockets.h>
#ifdef SSL_ENABLED #ifdef SSL_ENABLED
namespace fr namespace fr
@ -27,13 +29,12 @@ namespace fr
void SSLSocket::close_socket() void SSLSocket::close_socket()
{ {
if(is_connected) if(ssl_socket_descriptor->fd > -1)
{ {
if(ssl) if(ssl)
mbedtls_ssl_close_notify(ssl.get()); mbedtls_ssl_close_notify(ssl.get());
if(ssl_socket_descriptor) if(ssl_socket_descriptor)
mbedtls_net_free(ssl_socket_descriptor.get()); mbedtls_net_free(ssl_socket_descriptor.get());
is_connected = false;
} }
} }
@ -44,7 +45,7 @@ namespace fr
{ {
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{ {
is_connected = false; close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
} }
} }
@ -64,7 +65,7 @@ namespace fr
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
{ {
is_connected = false; close_socket();
return Socket::Status::Disconnected; return Socket::Status::Disconnected;
} }
else if(read <= 0) else if(read <= 0)
@ -136,7 +137,6 @@ namespace fr
} }
//Update state //Update state
is_connected = true;
remote_address = address + ":" + port; remote_address = address + ":" + port;
reconfigure_socket(); reconfigure_socket();
@ -150,7 +150,6 @@ namespace fr
void SSLSocket::set_net_context(std::unique_ptr<mbedtls_net_context> context) void SSLSocket::set_net_context(std::unique_ptr<mbedtls_net_context> context)
{ {
is_connected = true;
ssl_socket_descriptor = std::move(context); ssl_socket_descriptor = std::move(context);
reconfigure_socket(); reconfigure_socket();
} }

View File

@ -16,7 +16,6 @@ namespace fr
Socket::Socket() noexcept Socket::Socket() noexcept
: is_blocking(true), : is_blocking(true),
is_connected(false),
ai_family(AF_UNSPEC) ai_family(AF_UNSPEC)
{ {
if(instance_count == 0) if(instance_count == 0)
@ -38,7 +37,7 @@ namespace fr
Socket::Status Socket::send(Packet &packet) Socket::Status Socket::send(Packet &packet)
{ {
if(!is_connected) if(!connected())
return Socket::Disconnected; return Socket::Disconnected;
std::string &data = packet.get_buffer(); std::string &data = packet.get_buffer();
@ -47,7 +46,7 @@ namespace fr
Socket::Status Socket::send(Packet &&packet) Socket::Status Socket::send(Packet &&packet)
{ {
if(!is_connected) if(!connected())
return Socket::Disconnected; return Socket::Disconnected;
std::string &data = packet.get_buffer(); std::string &data = packet.get_buffer();
@ -56,7 +55,7 @@ namespace fr
Socket::Status Socket::receive(Packet &packet) Socket::Status Socket::receive(Packet &packet)
{ {
if(!is_connected) if(!connected())
return Socket::Disconnected; return Socket::Disconnected;
Socket::Status status; Socket::Status status;
@ -81,7 +80,7 @@ namespace fr
Socket::Status Socket::receive_all(void *dest, size_t buffer_size) Socket::Status Socket::receive_all(void *dest, size_t buffer_size)
{ {
if(!is_connected) if(!connected())
return Socket::Disconnected; return Socket::Disconnected;
ssize_t bytes_remaining = buffer_size; ssize_t bytes_remaining = buffer_size;

View File

@ -9,6 +9,7 @@ namespace fr
{ {
TcpSocket::TcpSocket() noexcept TcpSocket::TcpSocket() noexcept
: socket_descriptor(-1)
{ {
} }
@ -40,10 +41,10 @@ namespace fr
void TcpSocket::close_socket() void TcpSocket::close_socket()
{ {
if(is_connected) if(socket_descriptor > -1)
{ {
::closesocket(socket_descriptor); ::closesocket(socket_descriptor);
is_connected = false; socket_descriptor = -1;
} }
} }
@ -80,7 +81,6 @@ namespace fr
{ {
reconfigure_socket(); reconfigure_socket();
socket_descriptor = descriptor; socket_descriptor = descriptor;
is_connected = true;
} }
Socket::Status TcpSocket::connect(const std::string &address, const std::string &port) Socket::Status TcpSocket::connect(const std::string &address, const std::string &port)
@ -124,7 +124,6 @@ namespace fr
freeaddrinfo(info); freeaddrinfo(info);
//Update state now we've got a valid socket descriptor //Update state now we've got a valid socket descriptor
is_connected = true;
remote_address = address + ":" + port; remote_address = address + ":" + port;
reconfigure_socket(); reconfigure_socket();