diff --git a/.gitignore b/.gitignore index 8e24b65..82aa793 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea cmake-build-debug +cmake-build-release diff --git a/include/frnetlib/Packet.h b/include/frnetlib/Packet.h index 22f4ba6..964e9b9 100644 --- a/include/frnetlib/Packet.h +++ b/include/frnetlib/Packet.h @@ -12,11 +12,13 @@ namespace fr { +#define PACKET_HEADER_LENGTH sizeof(uint32_t) class Packet { public: Packet() noexcept - : buffer_offset(0) + : buffer_read_index(PACKET_HEADER_LENGTH), + buffer(PACKET_HEADER_LENGTH, '0') { } @@ -41,17 +43,6 @@ namespace fr *this << part; } - /*! - * Gets the data added to the packet - * - * @return A string containing all of the data added to the packet - */ - inline const std::string &get_buffer() const - { - return buffer; - } - - /* * Adds a vector to a packet */ @@ -108,8 +99,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); return *this; } @@ -131,8 +122,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohs(var); return *this; } @@ -155,8 +146,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohl(var); return *this; } @@ -179,8 +170,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohll(var); return *this; } @@ -203,8 +194,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohs((uint16_t)var); return *this; } @@ -227,8 +218,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohl((uint32_t)var); return *this; } @@ -251,8 +242,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohll((uint64_t)var); return *this; } @@ -275,8 +266,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var); + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var); var = ntohf(var); return *this; } @@ -299,8 +290,8 @@ namespace fr { assert_data_remaining(sizeof(var)); - memcpy(&var, &buffer[buffer_offset], sizeof(var)); - buffer_offset += sizeof(var);; + memcpy(&var, &buffer[buffer_read_index], sizeof(var)); + buffer_read_index += sizeof(var);; var = ntohd(var); return *this; } @@ -333,8 +324,8 @@ namespace fr uint32_t length; *this >> length; - var = buffer.substr(buffer_offset, length); - buffer_offset += length; + var = buffer.substr(buffer_read_index, length); + buffer_read_index += length; return *this; } @@ -354,20 +345,50 @@ namespace fr */ inline void clear() { - buffer.clear(); - buffer_offset = 0; + buffer.erase(PACKET_HEADER_LENGTH, buffer.size() - PACKET_HEADER_LENGTH); + buffer_read_index = PACKET_HEADER_LENGTH; } /*! - * Resets the buffer read cursor back to the beginning - * of the packet. + * Resets the read cursor back to 0, or a specified position. + * + * @param pos The buffer index to continue reading from. */ - inline void reset_read_cursor() + inline void reset_read_cursor(size_t pos = 0) { - buffer_offset = 0; + buffer_read_index = PACKET_HEADER_LENGTH + pos; + } + + /*! + * Reserves space in the internal packet buffer, + * for if you know how much data you expect to receive + * or send. + * + * @param bytes The number of bytes to reserve + */ + inline void reserve(size_t bytes) + { + buffer.reserve(PACKET_HEADER_LENGTH + bytes); } private: + friend class Socket; + + /*! + * Gets the data added to the packet + * + * @return A string containing all of the data added to the packet + */ + inline std::string &get_buffer() + { + //Update packet length first + uint32_t length = htonl((uint32_t)buffer.size() - PACKET_HEADER_LENGTH); + memcpy(&buffer[0], &length, sizeof(uint32_t)); + + //Then a reference to the buffer + return buffer; + } + /*! * Checks that there's enough data in the buffer to extract * a given number of bytes to prevent buffer overflows. @@ -377,12 +398,13 @@ namespace fr */ inline void assert_data_remaining(size_t required_space) { - if(buffer_offset + required_space > buffer.size()) + if(buffer_read_index + required_space > buffer.size()) throw std::out_of_range("Not enough bytes remaining in packet to extract requested"); } + std::string buffer; //Packet data buffer - size_t buffer_offset; //Current read position + size_t buffer_read_index; //Current read position }; } diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 27e6f2a..6db642e 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -98,17 +98,7 @@ namespace fr abort(); } - /*! - * Checks to see if there's data still in the socket's - * recv buffer. - * - * @return True if there is data in the buffer, false otherwise. - */ - virtual bool has_data() const override; - private: - std::string unprocessed_buffer; - std::unique_ptr recv_buffer; std::shared_ptr ssl_context; std::unique_ptr ssl_socket_descriptor; diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index c825f3b..37e6c60 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -5,7 +5,7 @@ #ifndef FRNETLIB_SOCKET_H #define FRNETLIB_SOCKET_H - +#include #include "NetworkEncoding.h" #include "Packet.h" @@ -95,7 +95,8 @@ namespace fr * @param packet The packet to send * @return True on success, false on failure. */ - Status send(const Packet &packet); + Status send(Packet &packet); + Status send(Packet &&packet); /*! * Receive a packet through the socket @@ -146,13 +147,6 @@ namespace fr */ void shutdown(); - /*! - * Checks to see if there's data still in the socket's - * recv buffer. - * - * @return True if there is data in the buffer, false otherwise. - */ - virtual bool has_data() const = 0; protected: /*! @@ -164,6 +158,8 @@ namespace fr std::string remote_address; bool is_blocking; bool is_connected; + std::mutex outbound_mutex; + std::mutex inbound_mutex; #ifdef _WIN32 static WSADATA wsaData; diff --git a/include/frnetlib/TcpListener.h b/include/frnetlib/TcpListener.h index d5c54d5..a84120e 100644 --- a/include/frnetlib/TcpListener.h +++ b/include/frnetlib/TcpListener.h @@ -45,7 +45,6 @@ private: virtual fr::Socket::Status send_raw(const char*, size_t){return Socket::Error;} virtual fr::Socket::Status receive_raw(void*, size_t, size_t&){return Socket::Error;} virtual int32_t get_socket_descriptor() const {return socket_descriptor;} - virtual bool has_data() const override {return false;}; }; } diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index 8797bb3..47ece61 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -19,9 +19,7 @@ public: TcpSocket() noexcept; virtual ~TcpSocket() noexcept; TcpSocket(TcpSocket &&other) noexcept - : unprocessed_buffer(std::move(other.unprocessed_buffer)), - recv_buffer(std::move(other.recv_buffer)), - socket_descriptor(other.socket_descriptor){} + : socket_descriptor(other.socket_descriptor){} void operator=(const TcpSocket &other)=delete; /*! @@ -95,20 +93,8 @@ public: */ int32_t get_socket_descriptor() const override; - /*! - * Checks to see if there's data still in the socket's - * recv buffer. - * - * @return True if there is data in the buffer, false otherwise. - */ - virtual bool has_data() const override; - protected: - std::string unprocessed_buffer; - std::unique_ptr recv_buffer; int32_t socket_descriptor; - std::mutex outbound_mutex; - std::mutex inbound_mutex; }; } diff --git a/main.cpp b/main.cpp index 5cb3f81..3428fbb 100644 --- a/main.cpp +++ b/main.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "frnetlib/Packet.h" #include "frnetlib/TcpSocket.h" #include "frnetlib/TcpListener.h" @@ -17,65 +18,50 @@ void server() { fr::TcpListener listener; - fr::TcpSocket client; + listener.listen("9092"); - listener.listen("8081"); - listener.accept(client); - - uint32_t packet_no = 0; + fr::TcpSocket socket; + listener.accept(socket); + uint64_t packet_count = 0; + auto last_print_time = std::chrono::system_clock::now(); while(true) { fr::Packet packet; - client.receive(packet); + if(socket.receive(packet) != fr::Socket::Success) + break; - uint32_t num = 0; - packet >> num; + std::string s1; + packet >> s1; - if(num != ++packet_no) + packet_count++; + if(last_print_time + std::chrono::seconds(1) < std::chrono::system_clock::now()) { - std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl; - return; + std::cout << "Packets per second: " << packet_count << std::endl; + packet_count = 0; + last_print_time = std::chrono::system_clock::now(); } } -} -void client() -{ - fr::TcpSocket server; - server.connect("127.0.0.1", "8081"); - - uint32_t packet_no = 0; - std::mutex m1; - - auto lam = [&]() - { - while(true) - { - m1.lock(); - fr::Packet packet; - packet << ++packet_no; - m1.unlock(); - - server.send(packet); - } - }; - - std::thread t1(lam); - std::thread t2(lam); - std::thread t3(lam); - std::thread t4(lam); - t1.join(); } int main() { - std::thread s1(server); + std::thread server_thread(server); std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::thread c1(client); - s1.join(); - c1.join(); - return 0; + fr::TcpSocket socket; + socket.connect("127.0.0.1", "9092"); + + std::string a(32384, 'c'); + fr::Packet packet; + while(true) + { + packet << a; + if(socket.send(packet) != fr::Socket::Success) + break; + packet.clear(); + } + server_thread.join(); } \ No newline at end of file diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 486eaf0..a6c5e6d 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -9,8 +9,7 @@ namespace fr { SSLSocket::SSLSocket(std::shared_ptr ssl_context_) noexcept - : recv_buffer(new char[RECV_CHUNK_SIZE]), - ssl_context(ssl_context_) + : ssl_context(ssl_context_) { //Initialise mbedtls structures mbedtls_ssl_config_init(&conf); @@ -55,42 +54,26 @@ namespace fr Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) { - std::lock_guard guard(inbound_mutex); - int read = MBEDTLS_ERR_SSL_WANT_READ; received = 0; - if(unprocessed_buffer.size() < data_size) + + while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) { - while(read == MBEDTLS_ERR_SSL_WANT_READ || read == MBEDTLS_ERR_SSL_WANT_WRITE) - { - read = mbedtls_ssl_read(ssl.get(), (unsigned char *)recv_buffer.get(), RECV_CHUNK_SIZE); - } - - if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) - { - is_connected = false; - return Socket::Status::Disconnected; - } - else if(read <= 0) - { - //No data. But no error occurred. - return Socket::Status::Success; - } - - received += read; - unprocessed_buffer += {recv_buffer.get(), (size_t)read}; - - if(received > data_size) - received = data_size; - } - else - { - received = data_size; + read = mbedtls_ssl_read(ssl.get(), (unsigned char *)data, data_size); } - //Copy data to where it needs to go - memcpy(data, &unprocessed_buffer[0], received); - unprocessed_buffer.erase(0, received); + if(read == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) + { + is_connected = false; + return Socket::Status::Disconnected; + } + else if(read <= 0) + { + //No data. But no error occurred. + return Socket::Status::Success; + } + + received += read; return Socket::Status::Success; } @@ -171,11 +154,6 @@ namespace fr ssl_socket_descriptor = std::move(context); reconfigure_socket(); } - - bool SSLSocket::has_data() const - { - return !unprocessed_buffer.empty(); - } } #endif \ No newline at end of file diff --git a/src/Socket.cpp b/src/Socket.cpp index d92416e..04220aa 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -2,6 +2,7 @@ // Created by fred on 06/12/16. // +#include #include "frnetlib/NetworkEncoding.h" #include "frnetlib/Socket.h" @@ -30,20 +31,21 @@ namespace fr #endif // _WIN32 } - Socket::Status Socket::send(const Packet &packet) + Socket::Status Socket::send(Packet &packet) { if(!is_connected) return Socket::Disconnected; - //Get packet data - std::string data = packet.get_buffer(); + std::string &data = packet.get_buffer(); + return send_raw(data.c_str(), data.size()); + } - //Prepend packet length - uint32_t length = htonl((uint32_t)data.size()); - data.insert(0, "1234"); - memcpy(&data[0], &length, sizeof(uint32_t)); + Socket::Status Socket::send(Packet &&packet) + { + if(!is_connected) + return Socket::Disconnected; - //Send it + std::string &data = packet.get_buffer(); return send_raw(data.c_str(), data.size()); } @@ -53,6 +55,7 @@ namespace fr return Socket::Disconnected; Socket::Status status; + std::lock_guard guard(inbound_mutex); //Try to read packet length uint32_t packet_length = 0; @@ -62,13 +65,11 @@ namespace fr packet_length = ntohl(packet_length); //Now we've got the length, read the rest of the data in - std::string data(packet_length, 'c'); - status = receive_all(&data[0], packet_length); + packet.buffer.resize(packet_length + PACKET_HEADER_LENGTH); + status = receive_all(&packet.buffer[PACKET_HEADER_LENGTH], packet_length); if(status != Socket::Status::Success) return status; - //Set the packet to what we've read - packet.set_buffer(std::move(data)); return Socket::Status::Success; } diff --git a/src/SocketSelector.cpp b/src/SocketSelector.cpp index 75d4a0e..24f4846 100644 --- a/src/SocketSelector.cpp +++ b/src/SocketSelector.cpp @@ -3,6 +3,7 @@ // #include +#include #include "frnetlib/SocketSelector.h" namespace fr diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index bf75320..8ca2e2d 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -9,7 +9,6 @@ namespace fr { TcpSocket::TcpSocket() noexcept - : recv_buffer(new char[RECV_CHUNK_SIZE]) { } @@ -56,44 +55,33 @@ namespace fr Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received) { - std::lock_guard guard(inbound_mutex); received = 0; - if(unprocessed_buffer.size() < buffer_size) + + //Read RECV_CHUNK_SIZE bytes into the recv buffer + ssize_t status = ::recv(socket_descriptor, data, buffer_size, 0); + + if(status > 0) { - //Read RECV_CHUNK_SIZE bytes into the recv buffer - ssize_t status = ::recv(socket_descriptor, recv_buffer.get(), RECV_CHUNK_SIZE, 0); - - if(status > 0) - { - unprocessed_buffer += {recv_buffer.get(), (size_t)status}; - received += status; - } - else - { - if(errno == EWOULDBLOCK || errno == EAGAIN) - { - return Socket::Status::WouldBlock; - } - else if(status == -1) - { - return Socket::Status::Error; - } - - is_connected = false; - return Socket::Status::Disconnected; - } - - if(received > buffer_size) - received = buffer_size; + received += status; } else { - received = buffer_size; + if(errno == EWOULDBLOCK || errno == EAGAIN) + { + return Socket::Status::WouldBlock; + } + else if(status == -1) + { + return Socket::Status::Error; + } + + is_connected = false; + return Socket::Status::Disconnected; } - //Copy data to where it needs to go - memcpy(data, &unprocessed_buffer[0], received); - unprocessed_buffer.erase(0, received); + if(received > buffer_size) + received = buffer_size; + return Socket::Status::Success; } @@ -162,9 +150,4 @@ namespace fr { return socket_descriptor; } - - bool TcpSocket::has_data() const - { - return !unprocessed_buffer.empty(); - } } \ No newline at end of file