Thread safety for send/receive on both Tcp & SSL sockets
Send/Recv often needs to be called multiple times to transfer all of the data. send_raw/receive_raw are now mutex protected and so both send & receive can be called simultaneously.
This commit is contained in:
parent
8a54e8994a
commit
8638b70fb8
@ -107,6 +107,9 @@ namespace fr
|
|||||||
std::unique_ptr<mbedtls_ssl_context> ssl;
|
std::unique_ptr<mbedtls_ssl_context> ssl;
|
||||||
mbedtls_ssl_config conf;
|
mbedtls_ssl_config conf;
|
||||||
uint32_t flags;
|
uint32_t flags;
|
||||||
|
|
||||||
|
std::mutex outbound_mutex;
|
||||||
|
std::mutex inbound_mutex;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
#define FRNETLIB_TCPSOCKET_H
|
#define FRNETLIB_TCPSOCKET_H
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
#include "Socket.h"
|
#include "Socket.h"
|
||||||
|
|
||||||
namespace fr
|
namespace fr
|
||||||
@ -95,6 +96,8 @@ protected:
|
|||||||
std::string unprocessed_buffer;
|
std::string unprocessed_buffer;
|
||||||
std::unique_ptr<char[]> recv_buffer;
|
std::unique_ptr<char[]> recv_buffer;
|
||||||
int32_t socket_descriptor;
|
int32_t socket_descriptor;
|
||||||
|
std::mutex outbound_mutex;
|
||||||
|
std::mutex inbound_mutex;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
76
main.cpp
76
main.cpp
@ -1,5 +1,8 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <frnetlib/SSLListener.h>
|
#include <frnetlib/SSLListener.h>
|
||||||
|
#include <thread>
|
||||||
|
#include <atomic>
|
||||||
|
#include <mutex>
|
||||||
#include "frnetlib/Packet.h"
|
#include "frnetlib/Packet.h"
|
||||||
#include "frnetlib/TcpSocket.h"
|
#include "frnetlib/TcpSocket.h"
|
||||||
#include "frnetlib/TcpListener.h"
|
#include "frnetlib/TcpListener.h"
|
||||||
@ -11,21 +14,68 @@
|
|||||||
#include "frnetlib/SSLContext.h"
|
#include "frnetlib/SSLContext.h"
|
||||||
#include "frnetlib/SSLListener.h"
|
#include "frnetlib/SSLListener.h"
|
||||||
|
|
||||||
|
void server()
|
||||||
|
{
|
||||||
|
fr::TcpListener listener;
|
||||||
|
fr::TcpSocket client;
|
||||||
|
|
||||||
|
listener.listen("8081");
|
||||||
|
listener.accept(client);
|
||||||
|
|
||||||
|
uint32_t packet_no = 0;
|
||||||
|
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
fr::Packet packet;
|
||||||
|
client.receive(packet);
|
||||||
|
|
||||||
|
|
||||||
|
uint32_t num = 0;
|
||||||
|
packet >> num;
|
||||||
|
|
||||||
|
if(num != ++packet_no)
|
||||||
|
{
|
||||||
|
std::cout << "Packet mismatch. Expected " << packet_no + 1 << ". Got " << num << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void client()
|
||||||
|
{
|
||||||
|
fr::TcpSocket server;
|
||||||
|
server.connect("127.0.0.1", "8081");
|
||||||
|
|
||||||
|
uint32_t packet_no = 0;
|
||||||
|
std::mutex m1;
|
||||||
|
|
||||||
|
auto lam = [&]()
|
||||||
|
{
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
m1.lock();
|
||||||
|
fr::Packet packet;
|
||||||
|
packet << ++packet_no;
|
||||||
|
m1.unlock();
|
||||||
|
|
||||||
|
server.send(packet);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::thread t1(lam);
|
||||||
|
std::thread t2(lam);
|
||||||
|
std::thread t3(lam);
|
||||||
|
std::thread t4(lam);
|
||||||
|
t1.join();
|
||||||
|
}
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
std::shared_ptr<fr::SSLContext> ssl_context(new fr::SSLContext("certs.crt"));
|
std::thread s1(server);
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
std::thread c1(client);
|
||||||
|
|
||||||
fr::HttpSocket<fr::SSLSocket> socket(ssl_context);
|
s1.join();
|
||||||
std::string addr;
|
c1.join();
|
||||||
std::cin >> addr;
|
|
||||||
socket.connect(addr, "443");
|
|
||||||
|
|
||||||
fr::HttpRequest request;
|
|
||||||
socket.send(request);
|
|
||||||
|
|
||||||
fr::HttpResponse response;
|
|
||||||
socket.receive(response);
|
|
||||||
|
|
||||||
std::cout << response.get_body() << std::endl;
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -40,6 +40,7 @@ namespace fr
|
|||||||
|
|
||||||
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
Socket::Status SSLSocket::send_raw(const char *data, size_t size)
|
||||||
{
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(outbound_mutex);
|
||||||
int error = 0;
|
int error = 0;
|
||||||
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
|
while((error = mbedtls_ssl_write(ssl.get(), (const unsigned char *)data, size)) <= 0)
|
||||||
{
|
{
|
||||||
@ -54,6 +55,8 @@ namespace fr
|
|||||||
|
|
||||||
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
|
Socket::Status SSLSocket::receive_raw(void *data, size_t data_size, size_t &received)
|
||||||
{
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(inbound_mutex);
|
||||||
|
|
||||||
int read = MBEDTLS_ERR_SSL_WANT_READ;
|
int read = MBEDTLS_ERR_SSL_WANT_READ;
|
||||||
received = 0;
|
received = 0;
|
||||||
if(unprocessed_buffer.size() < data_size)
|
if(unprocessed_buffer.size() < data_size)
|
||||||
|
|||||||
@ -21,6 +21,8 @@ namespace fr
|
|||||||
|
|
||||||
Socket::Status TcpSocket::send_raw(const char *data, size_t size)
|
Socket::Status TcpSocket::send_raw(const char *data, size_t size)
|
||||||
{
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(outbound_mutex);
|
||||||
|
|
||||||
size_t sent = 0;
|
size_t sent = 0;
|
||||||
while(sent < size)
|
while(sent < size)
|
||||||
{
|
{
|
||||||
@ -54,6 +56,7 @@ namespace fr
|
|||||||
|
|
||||||
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
|
Socket::Status TcpSocket::receive_raw(void *data, size_t buffer_size, size_t &received)
|
||||||
{
|
{
|
||||||
|
std::lock_guard<std::mutex> guard(inbound_mutex);
|
||||||
received = 0;
|
received = 0;
|
||||||
if(unprocessed_buffer.size() < buffer_size)
|
if(unprocessed_buffer.size() < buffer_size)
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user