diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index 4f6af63..54d317e 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -107,6 +107,9 @@ namespace fr std::unique_ptr ssl; mbedtls_ssl_config conf; uint32_t flags; + + std::mutex outbound_mutex; + std::mutex inbound_mutex; }; } diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index f396ab7..56a4f72 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -6,6 +6,7 @@ #define FRNETLIB_TCPSOCKET_H #include +#include #include "Socket.h" namespace fr @@ -95,6 +96,8 @@ protected: std::string unprocessed_buffer; std::unique_ptr recv_buffer; int32_t socket_descriptor; + std::mutex outbound_mutex; + std::mutex inbound_mutex; }; } diff --git a/main.cpp b/main.cpp index 269f0f6..5cb3f81 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,8 @@ #include #include +#include +#include +#include #include "frnetlib/Packet.h" #include "frnetlib/TcpSocket.h" #include "frnetlib/TcpListener.h" @@ -11,21 +14,68 @@ #include "frnetlib/SSLContext.h" #include "frnetlib/SSLListener.h" +void server() +{ + fr::TcpListener listener; + fr::TcpSocket client; + + listener.listen("8081"); + listener.accept(client); + + uint32_t packet_no = 0; + + while(true) + { + fr::Packet packet; + client.receive(packet); + + + uint32_t num = 0; + packet >> num; + + if(num != ++packet_no) + { + std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl; + return; + } + } +} + +void client() +{ + fr::TcpSocket server; + server.connect("127.0.0.1", "8081"); + + uint32_t packet_no = 0; + std::mutex m1; + + auto lam = [&]() + { + while(true) + { + m1.lock(); + fr::Packet packet; + packet << ++packet_no; + m1.unlock(); + + server.send(packet); + } + }; + + std::thread t1(lam); + std::thread t2(lam); + std::thread t3(lam); + std::thread t4(lam); + t1.join(); +} + int main() { - std::shared_ptr ssl_context(new fr::SSLContext("certs.crt")); + std::thread s1(server); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::thread c1(client); - fr::HttpSocket socket(ssl_context); - std::string addr; - std::cin >> addr; - socket.connect(addr, "443"); - - fr::HttpRequest request; - socket.send(request); - - fr::HttpResponse response; - socket.receive(response); - - std::cout << response.get_body() << std::endl; + s1.join(); + c1.join(); return 0; } \ No newline at end of file diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 375a02e..8d307b8 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -40,6 +40,7 @@ namespace fr Socket::Status SSLSocket::send_raw(const char *data, size_t size) { + std::lock_guard guard(outbound_mutex); int error = 0; while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0) { @@ -54,6 +55,8 @@ namespace fr Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received) { + std::lock_guard guard(inbound_mutex); + int read = MBEDTLS_ERR_SSL_WANT_READ; received = 0; if(unprocessed_buffer.size() < data_size) diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 8cbf87a..732be2d 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -21,6 +21,8 @@ namespace fr Socket::Status TcpSocket::send_raw(const char *data, size_t size) { + std::lock_guard guard(outbound_mutex); + size_t sent = 0; while(sent < size) { @@ -54,6 +56,7 @@ namespace fr Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received) { + std::lock_guard guard(inbound_mutex); received = 0; if(unprocessed_buffer.size() < buffer_size) {