Thread safety for send/receive on both Tcp & SSL sockets

Send/Recv often needs to be called multiple times to transfer all of the data. send_raw/receive_raw are now mutex protected and so both send & receive can be called simultaneously.
This commit is contained in:
Fred Nicolson 2016-12-23 18:23:25 +00:00
parent 8a54e8994a
commit 8638b70fb8
5 changed files with 75 additions and 13 deletions

View File

@ -107,6 +107,9 @@ namespace fr
std::unique_ptr<mbedtls_ssl_context> ssl;
mbedtls_ssl_config conf;
uint32_t flags;
std::mutex outbound_mutex;
std::mutex inbound_mutex;
};
}

View File

@ -6,6 +6,7 @@
#define FRNETLIB_TCPSOCKET_H
#include <memory>
#include <mutex>
#include "Socket.h"
namespace fr
@ -95,6 +96,8 @@ protected:
std::string unprocessed_buffer;
std::unique_ptr<char[]> recv_buffer;
int32_t socket_descriptor;
std::mutex outbound_mutex;
std::mutex inbound_mutex;
};
}

View File

@ -1,5 +1,8 @@
#include <iostream>
#include <frnetlib/SSLListener.h>
#include <thread>
#include <atomic>
#include <mutex>
#include "frnetlib/Packet.h"
#include "frnetlib/TcpSocket.h"
#include "frnetlib/TcpListener.h"
@ -11,21 +14,68 @@
#include "frnetlib/SSLContext.h"
#include "frnetlib/SSLListener.h"
void server()
{
fr::TcpListener listener;
fr::TcpSocket client;
listener.listen("8081");
listener.accept(client);
uint32_t packet_no = 0;
while(true)
{
fr::Packet packet;
client.receive(packet);
uint32_t num = 0;
packet >> num;
if(num != ++packet_no)
{
std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl;
return;
}
}
}
void client()
{
fr::TcpSocket server;
server.connect("127.0.0.1", "8081");
uint32_t packet_no = 0;
std::mutex m1;
auto lam = [&]()
{
while(true)
{
m1.lock();
fr::Packet packet;
packet << ++packet_no;
m1.unlock();
server.send(packet);
}
};
std::thread t1(lam);
std::thread t2(lam);
std::thread t3(lam);
std::thread t4(lam);
t1.join();
}
int main()
{
std::shared_ptr<fr::SSLContext> ssl_context(new fr::SSLContext("certs.crt"));
std::thread s1(server);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::thread c1(client);
fr::HttpSocket<fr::SSLSocket> socket(ssl_context);
std::string addr;
std::cin >> addr;
socket.connect(addr, "443");
fr::HttpRequest request;
socket.send(request);
fr::HttpResponse response;
socket.receive(response);
std::cout << response.get_body() << std::endl;
s1.join();
c1.join();
return 0;
}

View File

@ -40,6 +40,7 @@ namespace fr
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
{
std::lock_guard<std::mutex> guard(outbound_mutex);
int error = 0;
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
{
@ -54,6 +55,8 @@ namespace fr
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
{
std::lock_guard<std::mutex> guard(inbound_mutex);
int read = MBEDTLS_ERR_SSL_WANT_READ;
received = 0;
if(unprocessed_buffer.size() < data_size)

View File

@ -21,6 +21,8 @@ namespace fr
Socket::Status TcpSocket::send_raw(const char *data, size_t size)
{
std::lock_guard<std::mutex> guard(outbound_mutex);
size_t sent = 0;
while(sent < size)
{
@ -54,6 +56,7 @@ namespace fr
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
{
std::lock_guard<std::mutex> guard(inbound_mutex);
received = 0;
if(unprocessed_buffer.size() < buffer_size)
{