Added tests. Improved error checking. Bug fixes.

Added fr::TcpListener tests.
This commit is contained in:
Unknown 2017-06-05 20:13:50 +01:00
parent 08ed4a2354
commit 028677b01a
10 changed files with 152 additions and 47 deletions

View File

@ -41,7 +41,7 @@ namespace fr
* *
* @param header_end_pos The position in 'body' of the end of the header * @param header_end_pos The position in 'body' of the end of the header
*/ */
void parse_header(int32_t header_end_pos); bool parse_header(int32_t header_end_pos);
/*! /*!
* Parses the POST data from the body * Parses the POST data from the body

View File

@ -41,7 +41,7 @@ namespace fr
* *
* @param header_end_pos The position in 'body' of the end of the header * @param header_end_pos The position in 'body' of the end of the header
*/ */
void parse_header(int32_t header_end_pos); bool parse_header(int32_t header_end_pos);
//State //State
bool header_ended; bool header_ended;

View File

@ -39,12 +39,19 @@ namespace fr
* Calls the shutdown syscall on the socket. * Calls the shutdown syscall on the socket.
* So you can receive data but not send. * So you can receive data but not send.
* *
* Note: THIS DOES NOT CLOSE THE SOCKET. SHUTDOWN AND CLOSE ARE TWO DIFFERENT THINGS.
*
* This can be called on a blocking socket to force * This can be called on a blocking socket to force
* it to immediately return (you might want to do this if * it to immediately return (you might want to do this if
* you're exiting and need the blocking socket to return). * you're exiting and need the blocking socket to return).
*/ */
virtual void shutdown()=0; virtual void shutdown()=0;
/*!
* Closes the socket
*/
virtual void close_socket()=0;
/*! /*!
* Gets the socket descriptor. * Gets the socket descriptor.
* *

View File

@ -44,6 +44,11 @@ namespace fr
*/ */
virtual Socket::Status accept(Socket &client) override; virtual Socket::Status accept(Socket &client) override;
/*!
* Closes the socket
*/
virtual void close_socket() override;
/*! /*!
* Calls the shutdown syscall on the socket. * Calls the shutdown syscall on the socket.
* So you can receive data but not send. * So you can receive data but not send.

View File

@ -16,7 +16,7 @@ namespace fr
class TcpListener : public Listener class TcpListener : public Listener
{ {
public: public:
TcpListener() = default; TcpListener();
virtual ~TcpListener() override; virtual ~TcpListener() override;
TcpListener(TcpListener &&o) = default; TcpListener(TcpListener &&o) = default;
@ -60,6 +60,11 @@ public:
*/ */
virtual void set_socket_descriptor(int32_t descriptor) override; virtual void set_socket_descriptor(int32_t descriptor) override;
/*!
* Closes the socket
*/
virtual void close_socket() override;
private: private:
int32_t socket_descriptor; int32_t socket_descriptor;

View File

@ -33,7 +33,8 @@ namespace fr
} }
else else
{ {
parse_header(header_end); if(!parse_header(header_end))
return false;
body.clear(); body.clear();
} }
content_length += 2; //The empty line between header and data content_length += 2; //The empty line between header and data
@ -52,29 +53,38 @@ namespace fr
return true; return true;
} }
void HttpRequest::parse_header(int32_t header_end_pos) bool HttpRequest::parse_header(int32_t header_end_pos)
{ {
//Split the header into lines try
size_t line = 0;
std::vector<std::string> header_lines = split_string(body.substr(0, header_end_pos));
if(header_lines.empty())
return;
//Parse request type & uri
parse_header_type(header_lines[line]);
parse_header_uri(header_lines[line]);
line++;
//Read in headers
for(; line < header_lines.size(); line++)
{ {
parse_header_line(header_lines[line]); //Split the header into lines
} size_t line = 0;
std::vector<std::string> header_lines = split_string(body.substr(0, header_end_pos));
if(header_lines.empty())
return false;
//Parse request type & uri
parse_header_type(header_lines[line]);
parse_header_uri(header_lines[line]);
line++;
//Read in headers
for(; line < header_lines.size(); line++)
{
parse_header_line(header_lines[line]);
}
//Store content length value if it exists
auto length_header_iter = header_data.find("content-length");
if(length_header_iter != header_data.end())
content_length = (size_t)std::stoull(length_header_iter->second);
}
catch(const std::exception &e)
{
return false;
}
return true;
//Store content length value if it exists
auto length_header_iter = header_data.find("content-length");
if(length_header_iter != header_data.end())
content_length = (size_t)std::stoull(length_header_iter->second);
} }
std::string HttpRequest::construct(const std::string &host) const std::string HttpRequest::construct(const std::string &host) const

View File

@ -62,24 +62,32 @@ namespace fr
return response; return response;
} }
void HttpResponse::parse_header(int32_t header_end_pos) bool HttpResponse::parse_header(int32_t header_end_pos)
{ {
//Split the header into lines try
size_t line = 0;
std::vector<std::string> header_lines = split_string(body.substr(0, header_end_pos));
if(header_lines.empty())
return;
line++;
//Read in headers
for(; line < header_lines.size(); line++)
{ {
parse_header_line(header_lines[line]); //Split the header into lines
} size_t line = 0;
std::vector<std::string> header_lines = split_string(body.substr(0, header_end_pos));
if(header_lines.empty())
return false;
line++;
//Store content length value if it exists //Read in headers
auto length_header_iter = header_data.find("content-length"); for(; line < header_lines.size(); line++)
if(length_header_iter != header_data.end()) {
content_length = std::stoull(length_header_iter->second); parse_header_line(header_lines[line]);
}
//Store content length value if it exists
auto length_header_iter = header_data.find("content-length");
if(length_header_iter != header_data.end())
content_length = std::stoull(length_header_iter->second);
}
catch(const std::exception &e)
{
return false;
}
return true;
} }
} }

View File

@ -15,7 +15,7 @@ namespace fr
: ssl_context(ssl_context_) : ssl_context(ssl_context_)
{ {
//Initialise SSL objects required //Initialise SSL objects required
mbedtls_net_init(&listen_fd); listen_fd.fd = -1;
mbedtls_ssl_config_init(&conf); mbedtls_ssl_config_init(&conf);
mbedtls_x509_crt_init(&srvcert); mbedtls_x509_crt_init(&srvcert);
mbedtls_pk_init(&pkey); mbedtls_pk_init(&pkey);
@ -64,7 +64,7 @@ namespace fr
SSLListener::~SSLListener() SSLListener::~SSLListener()
{ {
mbedtls_net_free(&listen_fd); close_socket();
mbedtls_x509_crt_free(&srvcert); mbedtls_x509_crt_free(&srvcert);
mbedtls_pk_free(&pkey); mbedtls_pk_free(&pkey);
mbedtls_ssl_config_free(&conf); mbedtls_ssl_config_free(&conf);
@ -73,6 +73,8 @@ namespace fr
Socket::Status fr::SSLListener::listen(const std::string &port) Socket::Status fr::SSLListener::listen(const std::string &port)
{ {
//This is a hack. mbedtls doesn't support specifying the address family. //This is a hack. mbedtls doesn't support specifying the address family.
close_socket();
mbedtls_net_init(&listen_fd);
fr::TcpListener tcp_listen; fr::TcpListener tcp_listen;
tcp_listen.set_inet_version(ai_family); tcp_listen.set_inet_version(ai_family);
if(tcp_listen.listen(port) != fr::Socket::Success) if(tcp_listen.listen(port) != fr::Socket::Success)
@ -142,5 +144,14 @@ namespace fr
listen_fd.fd = descriptor; listen_fd.fd = descriptor;
} }
void SSLListener::close_socket()
{
if(listen_fd.fd != -1)
{
mbedtls_net_free(&listen_fd);
listen_fd.fd = -1;
}
}
} }
#endif #endif

View File

@ -9,13 +9,16 @@ namespace fr
const int yes = 1; const int yes = 1;
const int no = 0; const int no = 0;
TcpListener::TcpListener()
: socket_descriptor(-1)
{
}
TcpListener::~TcpListener() TcpListener::~TcpListener()
{ {
if(socket_descriptor > -1) close_socket();
{
closesocket(socket_descriptor);
socket_descriptor = -1;
}
} }
Socket::Status TcpListener::listen(const std::string &port) Socket::Status TcpListener::listen(const std::string &port)
@ -122,4 +125,13 @@ namespace fr
{ {
socket_descriptor = descriptor; socket_descriptor = descriptor;
} }
void TcpListener::close_socket()
{
if(socket_descriptor > -1)
{
closesocket(socket_descriptor);
socket_descriptor = -1;
}
}
} }

47
tests/TcpListenerTest.cpp Normal file
View File

@ -0,0 +1,47 @@
//
// Created by fred on 05/06/17.
//
#include <gtest/gtest.h>
#include <frnetlib/TcpListener.h>
#include <thread>
TEST(TcpListenerTest, listner_listen)
{
fr::TcpListener listener;
ASSERT_EQ(listener.get_socket_descriptor(), -1);
fr::Socket::Status ret = listener.listen("9090");
ASSERT_EQ(ret, fr::Socket::Success);
listener.close_socket();
ASSERT_EQ(listener.get_socket_descriptor(), -1);
}
TEST(TcpListenerTest, listener_accept)
{
fr::TcpListener listener;
listener.set_inet_version(fr::Socket::IP::v4);
if(listener.listen("9095") != fr::Socket::Success)
FAIL();
auto client_thread = []()
{
fr::TcpSocket socket;
socket.set_inet_version(fr::Socket::IP::v4);
auto ret = socket.connect("127.0.0.1", "9095");
ASSERT_EQ(ret, fr::Socket::Success);
};
std::thread t1(client_thread);
fr::TcpSocket socket;
auto ret = listener.accept(socket);
ASSERT_EQ(ret, fr::Socket::Success);
t1.join();
}
TEST(TcpListenerTest, set_descriptor)
{
fr::TcpListener listener;
listener.set_socket_descriptor(-20);
ASSERT_EQ(listener.get_socket_descriptor(), -20);
}