Thread safety for send/receive on both Tcp & SSL sockets
Send/Recv often needs to be called multiple times to transfer all of the data. send_raw/receive_raw are now mutex protected and so both send & receive can be called simultaneously.
This commit is contained in:
parent
8a54e8994a
commit
8638b70fb8
@ -107,6 +107,9 @@ namespace fr
|
||||
std::unique_ptr<mbedtls_ssl_context> ssl;
|
||||
mbedtls_ssl_config conf;
|
||||
uint32_t flags;
|
||||
|
||||
std::mutex outbound_mutex;
|
||||
std::mutex inbound_mutex;
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#define FRNETLIB_TCPSOCKET_H
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include "Socket.h"
|
||||
|
||||
namespace fr
|
||||
@ -95,6 +96,8 @@ protected:
|
||||
std::string unprocessed_buffer;
|
||||
std::unique_ptr<char[]> recv_buffer;
|
||||
int32_t socket_descriptor;
|
||||
std::mutex outbound_mutex;
|
||||
std::mutex inbound_mutex;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
76
main.cpp
76
main.cpp
@ -1,5 +1,8 @@
|
||||
#include <iostream>
|
||||
#include <frnetlib/SSLListener.h>
|
||||
#include <thread>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include "frnetlib/Packet.h"
|
||||
#include "frnetlib/TcpSocket.h"
|
||||
#include "frnetlib/TcpListener.h"
|
||||
@ -11,21 +14,68 @@
|
||||
#include "frnetlib/SSLContext.h"
|
||||
#include "frnetlib/SSLListener.h"
|
||||
|
||||
void server()
|
||||
{
|
||||
fr::TcpListener listener;
|
||||
fr::TcpSocket client;
|
||||
|
||||
listener.listen("8081");
|
||||
listener.accept(client);
|
||||
|
||||
uint32_t packet_no = 0;
|
||||
|
||||
while(true)
|
||||
{
|
||||
fr::Packet packet;
|
||||
client.receive(packet);
|
||||
|
||||
|
||||
uint32_t num = 0;
|
||||
packet >> num;
|
||||
|
||||
if(num != ++packet_no)
|
||||
{
|
||||
std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void client()
|
||||
{
|
||||
fr::TcpSocket server;
|
||||
server.connect("127.0.0.1", "8081");
|
||||
|
||||
uint32_t packet_no = 0;
|
||||
std::mutex m1;
|
||||
|
||||
auto lam = [&]()
|
||||
{
|
||||
while(true)
|
||||
{
|
||||
m1.lock();
|
||||
fr::Packet packet;
|
||||
packet << ++packet_no;
|
||||
m1.unlock();
|
||||
|
||||
server.send(packet);
|
||||
}
|
||||
};
|
||||
|
||||
std::thread t1(lam);
|
||||
std::thread t2(lam);
|
||||
std::thread t3(lam);
|
||||
std::thread t4(lam);
|
||||
t1.join();
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
std::shared_ptr<fr::SSLContext> ssl_context(new fr::SSLContext("certs.crt"));
|
||||
std::thread s1(server);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
std::thread c1(client);
|
||||
|
||||
fr::HttpSocket<fr::SSLSocket> socket(ssl_context);
|
||||
std::string addr;
|
||||
std::cin >> addr;
|
||||
socket.connect(addr, "443");
|
||||
|
||||
fr::HttpRequest request;
|
||||
socket.send(request);
|
||||
|
||||
fr::HttpResponse response;
|
||||
socket.receive(response);
|
||||
|
||||
std::cout << response.get_body() << std::endl;
|
||||
s1.join();
|
||||
c1.join();
|
||||
return 0;
|
||||
}
|
||||
@ -40,6 +40,7 @@ namespace fr
|
||||
|
||||
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(outbound_mutex);
|
||||
int error = 0;
|
||||
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
|
||||
{
|
||||
@ -54,6 +55,8 @@ namespace fr
|
||||
|
||||
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(inbound_mutex);
|
||||
|
||||
int read = MBEDTLS_ERR_SSL_WANT_READ;
|
||||
received = 0;
|
||||
if(unprocessed_buffer.size() < data_size)
|
||||
|
||||
@ -21,6 +21,8 @@ namespace fr
|
||||
|
||||
Socket::Status TcpSocket::send_raw(const char *data, size_t size)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(outbound_mutex);
|
||||
|
||||
size_t sent = 0;
|
||||
while(sent < size)
|
||||
{
|
||||
@ -54,6 +56,7 @@ namespace fr
|
||||
|
||||
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(inbound_mutex);
|
||||
received = 0;
|
||||
if(unprocessed_buffer.size() < buffer_size)
|
||||
{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user