Packet and socket optimisations

Sending/receiving data on a socket is roughly 5.5x faster, after removing data copies and buffer re-allocations.
This commit is contained in:
Fred Nicolson 2017-01-25 22:57:11 +00:00
parent 6b2947932a
commit fad7d0b81f
11 changed files with 148 additions and 205 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.idea .idea
cmake-build-debug cmake-build-debug
cmake-build-release

View File

@ -12,11 +12,13 @@
namespace fr namespace fr
{ {
#define PACKET_HEADER_LENGTH sizeof(uint32_t)
class Packet class Packet
{ {
public: public:
Packet() noexcept Packet() noexcept
: buffer_offset(0) : buffer_read_index(PACKET_HEADER_LENGTH),
buffer(PACKET_HEADER_LENGTH, '0')
{ {
} }
@ -41,17 +43,6 @@ namespace fr
*this << part; *this << part;
} }
/*!
* Gets the data added to the packet
*
* @return A string containing all of the data added to the packet
*/
inline const std::string &get_buffer() const
{
return buffer;
}
/* /*
* Adds a vector to a packet * Adds a vector to a packet
*/ */
@ -108,8 +99,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
return *this; return *this;
} }
@ -131,8 +122,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohs(var); var = ntohs(var);
return *this; return *this;
} }
@ -155,8 +146,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohl(var); var = ntohl(var);
return *this; return *this;
} }
@ -179,8 +170,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohll(var); var = ntohll(var);
return *this; return *this;
} }
@ -203,8 +194,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohs((uint16_t)var); var = ntohs((uint16_t)var);
return *this; return *this;
} }
@ -227,8 +218,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohl((uint32_t)var); var = ntohl((uint32_t)var);
return *this; return *this;
} }
@ -251,8 +242,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohll((uint64_t)var); var = ntohll((uint64_t)var);
return *this; return *this;
} }
@ -275,8 +266,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var); buffer_read_index += sizeof(var);
var = ntohf(var); var = ntohf(var);
return *this; return *this;
} }
@ -299,8 +290,8 @@ namespace fr
{ {
assert_data_remaining(sizeof(var)); assert_data_remaining(sizeof(var));
memcpy(&var, &buffer[buffer_offset], sizeof(var)); memcpy(&var, &buffer[buffer_read_index], sizeof(var));
buffer_offset += sizeof(var);; buffer_read_index += sizeof(var);;
var = ntohd(var); var = ntohd(var);
return *this; return *this;
} }
@ -333,8 +324,8 @@ namespace fr
uint32_t length; uint32_t length;
*this >> length; *this >> length;
var = buffer.substr(buffer_offset, length); var = buffer.substr(buffer_read_index, length);
buffer_offset += length; buffer_read_index += length;
return *this; return *this;
} }
@ -354,20 +345,50 @@ namespace fr
*/ */
inline void clear() inline void clear()
{ {
buffer.clear(); buffer.erase(PACKET_HEADER_LENGTH, buffer.size() - PACKET_HEADER_LENGTH);
buffer_offset = 0; buffer_read_index = PACKET_HEADER_LENGTH;
} }
/*! /*!
* Resets the buffer read cursor back to the beginning * Resets the read cursor back to 0, or a specified position.
* of the packet. *
* @param pos The buffer index to continue reading from.
*/ */
inline void reset_read_cursor() inline void reset_read_cursor(size_t pos = 0)
{ {
buffer_offset = 0; buffer_read_index = PACKET_HEADER_LENGTH + pos;
}
/*!
* Reserves space in the internal packet buffer,
* for if you know how much data you expect to receive
* or send.
*
* @param bytes The number of bytes to reserve
*/
inline void reserve(size_t bytes)
{
buffer.reserve(PACKET_HEADER_LENGTH + bytes);
} }
private: private:
friend class Socket;
/*!
* Gets the data added to the packet
*
* @return A string containing all of the data added to the packet
*/
inline std::string &get_buffer()
{
//Update packet length first
uint32_t length = htonl((uint32_t)buffer.size() - PACKET_HEADER_LENGTH);
memcpy(&buffer[0], &length, sizeof(uint32_t));
//Then a reference to the buffer
return buffer;
}
/*! /*!
* Checks that there's enough data in the buffer to extract * Checks that there's enough data in the buffer to extract
* a given number of bytes to prevent buffer overflows. * a given number of bytes to prevent buffer overflows.
@ -377,12 +398,13 @@ namespace fr
*/ */
inline void assert_data_remaining(size_t required_space) inline void assert_data_remaining(size_t required_space)
{ {
if(buffer_offset + required_space > buffer.size()) if(buffer_read_index + required_space > buffer.size())
throw std::out_of_range("Not enough bytes remaining in packet to extract requested"); throw std::out_of_range("Not enough bytes remaining in packet to extract requested");
} }
std::string buffer; //Packet data buffer std::string buffer; //Packet data buffer
size_t buffer_offset; //Current read position size_t buffer_read_index; //Current read position
}; };
} }

View File

@ -98,17 +98,7 @@ namespace fr
abort(); abort();
} }
/*!
* Checks to see if there's data still in the socket's
* recv buffer.
*
* @return True if there is data in the buffer, false otherwise.
*/
virtual bool has_data() const override;
private: private:
std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer;
std::shared_ptr<SSLContext> ssl_context; std::shared_ptr<SSLContext> ssl_context;
std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor; std::unique_ptr<mbedtls_net_context> ssl_socket_descriptor;

View File

@ -5,7 +5,7 @@
#ifndef FRNETLIB_SOCKET_H #ifndef FRNETLIB_SOCKET_H
#define FRNETLIB_SOCKET_H #define FRNETLIB_SOCKET_H
#include <mutex>
#include "NetworkEncoding.h" #include "NetworkEncoding.h"
#include "Packet.h" #include "Packet.h"
@ -95,7 +95,8 @@ namespace fr
* @param packet The packet to send * @param packet The packet to send
* @return True on success, false on failure. * @return True on success, false on failure.
*/ */
Status send(const Packet &packet); Status send(Packet &packet);
Status send(Packet &&packet);
/*! /*!
* Receive a packet through the socket * Receive a packet through the socket
@ -146,13 +147,6 @@ namespace fr
*/ */
void shutdown(); void shutdown();
/*!
* Checks to see if there's data still in the socket's
* recv buffer.
*
* @return True if there is data in the buffer, false otherwise.
*/
virtual bool has_data() const = 0;
protected: protected:
/*! /*!
@ -164,6 +158,8 @@ namespace fr
std::string remote_address; std::string remote_address;
bool is_blocking; bool is_blocking;
bool is_connected; bool is_connected;
std::mutex outbound_mutex;
std::mutex inbound_mutex;
#ifdef _WIN32 #ifdef _WIN32
static WSADATA wsaData; static WSADATA wsaData;

View File

@ -45,7 +45,6 @@ private:
virtual fr::Socket::Status send_raw(const char*, size_t){return Socket::Error;} virtual fr::Socket::Status send_raw(const char*, size_t){return Socket::Error;}
virtual fr::Socket::Status receive_raw(void*, size_t, size_t&){return Socket::Error;} virtual fr::Socket::Status receive_raw(void*, size_t, size_t&){return Socket::Error;}
virtual int32_t get_socket_descriptor() const {return socket_descriptor;} virtual int32_t get_socket_descriptor() const {return socket_descriptor;}
virtual bool has_data() const override {return false;};
}; };
} }

View File

@ -19,9 +19,7 @@ public:
TcpSocket() noexcept; TcpSocket() noexcept;
virtual ~TcpSocket() noexcept; virtual ~TcpSocket() noexcept;
TcpSocket(TcpSocket &&other) noexcept TcpSocket(TcpSocket &&other) noexcept
: unprocessed_buffer(std::move(other.unprocessed_buffer)), : socket_descriptor(other.socket_descriptor){}
recv_buffer(std::move(other.recv_buffer)),
socket_descriptor(other.socket_descriptor){}
void operator=(const TcpSocket &other)=delete; void operator=(const TcpSocket &other)=delete;
/*! /*!
@ -95,20 +93,8 @@ public:
*/ */
int32_t get_socket_descriptor() const override; int32_t get_socket_descriptor() const override;
/*!
* Checks to see if there's data still in the socket's
* recv buffer.
*
* @return True if there is data in the buffer, false otherwise.
*/
virtual bool has_data() const override;
protected: protected:
std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer;
int32_t socket_descriptor; int32_t socket_descriptor;
std::mutex outbound_mutex;
std::mutex inbound_mutex;
}; };
} }

View File

@ -3,6 +3,7 @@
#include <thread> #include <thread>
#include <atomic> #include <atomic>
#include <mutex> #include <mutex>
#include <chrono>
#include "frnetlib/Packet.h" #include "frnetlib/Packet.h"
#include "frnetlib/TcpSocket.h" #include "frnetlib/TcpSocket.h"
#include "frnetlib/TcpListener.h" #include "frnetlib/TcpListener.h"
@ -17,65 +18,50 @@
void server() void server()
{ {
fr::TcpListener listener; fr::TcpListener listener;
fr::TcpSocket client; listener.listen("9092");
listener.listen("8081"); fr::TcpSocket socket;
listener.accept(client); listener.accept(socket);
uint32_t packet_no = 0;
uint64_t packet_count = 0;
auto last_print_time = std::chrono::system_clock::now();
while(true) while(true)
{ {
fr::Packet packet; fr::Packet packet;
client.receive(packet); if(socket.receive(packet) != fr::Socket::Success)
break;
uint32_t num = 0; std::string s1;
packet >> num; packet >> s1;
if(num != ++packet_no) packet_count++;
if(last_print_time + std::chrono::seconds(1) < std::chrono::system_clock::now())
{ {
std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl; std::cout << "Packets per second: " << packet_count << std::endl;
return; packet_count = 0;
last_print_time = std::chrono::system_clock::now();
} }
} }
}
void client()
{
fr::TcpSocket server;
server.connect("127.0.0.1", "8081");
uint32_t packet_no = 0;
std::mutex m1;
auto lam = [&]()
{
while(true)
{
m1.lock();
fr::Packet packet;
packet << ++packet_no;
m1.unlock();
server.send(packet);
}
};
std::thread t1(lam);
std::thread t2(lam);
std::thread t3(lam);
std::thread t4(lam);
t1.join();
} }
int main() int main()
{ {
std::thread s1(server); std::thread server_thread(server);
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::thread c1(client);
s1.join(); fr::TcpSocket socket;
c1.join(); socket.connect("127.0.0.1", "9092");
return 0;
std::string a(32384, 'c');
fr::Packet packet;
while(true)
{
packet << a;
if(socket.send(packet) != fr::Socket::Success)
break;
packet.clear();
}
server_thread.join();
} }

View File

@ -9,8 +9,7 @@
namespace fr namespace fr
{ {
SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept SSLSocket::SSLSocket(std::shared_ptr<SSLContext> ssl_context_) noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE]), : ssl_context(ssl_context_)
ssl_context(ssl_context_)
{ {
//Initialise mbedtls structures //Initialise mbedtls structures
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
@ -55,15 +54,12 @@ namespace fr
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
{ {
std::lock_guard<std::mutex> guard(inbound_mutex);
int read = MBEDTLS_ERR_SSL_WANT_READ; int read = MBEDTLS_ERR_SSL_WANT_READ;
received = 0; received = 0;
if(unprocessed_buffer.size() < data_size)
{
while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE)
{ {
read = mbedtls_ssl_read(ssl.get(), (unsigned char *)recv_buffer.get(), RECV_CHUNK_SIZE); read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size);
} }
if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
@ -78,19 +74,6 @@ namespace fr
} }
received += read; received += read;
unprocessed_buffer += {recv_buffer.get(), (size_t)read};
if(received > data_size)
received = data_size;
}
else
{
received = data_size;
}
//Copy data to where it needs to go
memcpy(data, &unprocessed_buffer[0], received);
unprocessed_buffer.erase(0, received);
return Socket::Status::Success; return Socket::Status::Success;
} }
@ -171,11 +154,6 @@ namespace fr
ssl_socket_descriptor = std::move(context); ssl_socket_descriptor = std::move(context);
reconfigure_socket(); reconfigure_socket();
} }
bool SSLSocket::has_data() const
{
return !unprocessed_buffer.empty();
}
} }
#endif #endif

View File

@ -2,6 +2,7 @@
// Created by fred on 06/12/16. // Created by fred on 06/12/16.
// //
#include <mutex>
#include "frnetlib/NetworkEncoding.h" #include "frnetlib/NetworkEncoding.h"
#include "frnetlib/Socket.h" #include "frnetlib/Socket.h"
@ -30,20 +31,21 @@ namespace fr
#endif // _WIN32 #endif // _WIN32
} }
Socket::Status Socket::send(const Packet &packet) Socket::Status Socket::send(Packet &packet)
{ {
if(!is_connected) if(!is_connected)
return Socket::Disconnected; return Socket::Disconnected;
//Get packet data std::string &data = packet.get_buffer();
std::string data = packet.get_buffer(); return send_raw(data.c_str(), data.size());
}
//Prepend packet length Socket::Status Socket::send(Packet &&packet)
uint32_t length = htonl((uint32_t)data.size()); {
data.insert(0, "1234"); if(!is_connected)
memcpy(&data[0], &length, sizeof(uint32_t)); return Socket::Disconnected;
//Send it std::string &data = packet.get_buffer();
return send_raw(data.c_str(), data.size()); return send_raw(data.c_str(), data.size());
} }
@ -53,6 +55,7 @@ namespace fr
return Socket::Disconnected; return Socket::Disconnected;
Socket::Status status; Socket::Status status;
std::lock_guard<std::mutex> guard(inbound_mutex);
//Try to read packet length //Try to read packet length
uint32_t packet_length = 0; uint32_t packet_length = 0;
@ -62,13 +65,11 @@ namespace fr
packet_length = ntohl(packet_length); packet_length = ntohl(packet_length);
//Now we've got the length, read the rest of the data in //Now we've got the length, read the rest of the data in
std::string data(packet_length, 'c'); packet.buffer.resize(packet_length + PACKET_HEADER_LENGTH);
status = receive_all(&data[0], packet_length); status = receive_all(&packet.buffer[PACKET_HEADER_LENGTH], packet_length);
if(status != Socket::Status::Success) if(status != Socket::Status::Success)
return status; return status;
//Set the packet to what we've read
packet.set_buffer(std::move(data));
return Socket::Status::Success; return Socket::Status::Success;
} }

View File

@ -3,6 +3,7 @@
// //
#include <thread> #include <thread>
#include <mutex>
#include "frnetlib/SocketSelector.h" #include "frnetlib/SocketSelector.h"
namespace fr namespace fr

View File

@ -9,7 +9,6 @@ namespace fr
{ {
TcpSocket::TcpSocket() noexcept TcpSocket::TcpSocket() noexcept
: recv_buffer(new char[RECV_CHUNK_SIZE])
{ {
} }
@ -56,16 +55,13 @@ namespace fr
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received) Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
{ {
std::lock_guard<std::mutex> guard(inbound_mutex);
received = 0; received = 0;
if(unprocessed_buffer.size() < buffer_size)
{
//Read RECV_CHUNK_SIZE bytes into the recv buffer //Read RECV_CHUNK_SIZE bytes into the recv buffer
ssize_t status = ::recv(socket_descriptor, recv_buffer.get(), RECV_CHUNK_SIZE, 0); ssize_t status = ::recv(socket_descriptor, data, buffer_size, 0);
if(status > 0) if(status > 0)
{ {
unprocessed_buffer += {recv_buffer.get(), (size_t)status};
received += status; received += status;
} }
else else
@ -85,15 +81,7 @@ namespace fr
if(received > buffer_size) if(received > buffer_size)
received = buffer_size; received = buffer_size;
}
else
{
received = buffer_size;
}
//Copy data to where it needs to go
memcpy(data, &unprocessed_buffer[0], received);
unprocessed_buffer.erase(0, received);
return Socket::Status::Success; return Socket::Status::Success;
} }
@ -162,9 +150,4 @@ namespace fr
{ {
return socket_descriptor; return socket_descriptor;
} }
bool TcpSocket::has_data() const
{
return !unprocessed_buffer.empty();
}
} }