From 028677b01ae150dee123a9cea9012dfa55f6d88a Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 5 Jun 2017 20:13:50 +0100 Subject: [PATCH] Added tests. Improved error checking. Bug fixes. Added fr::TcpListener tests. --- include/frnetlib/HttpRequest.h | 2 +- include/frnetlib/HttpResponse.h | 2 +- include/frnetlib/Listener.h | 7 +++++ include/frnetlib/SSLListener.h | 5 ++++ include/frnetlib/TcpListener.h | 7 ++++- src/HttpRequest.cpp | 52 ++++++++++++++++++++------------- src/HttpResponse.cpp | 40 +++++++++++++++---------- src/SSLListener.cpp | 15 ++++++++-- src/TcpListener.cpp | 22 ++++++++++---- tests/TcpListenerTest.cpp | 47 +++++++++++++++++++++++++++++ 10 files changed, 152 insertions(+), 47 deletions(-) create mode 100644 tests/TcpListenerTest.cpp diff --git a/include/frnetlib/HttpRequest.h b/include/frnetlib/HttpRequest.h index ad8f9fd..0b37f33 100644 --- a/include/frnetlib/HttpRequest.h +++ b/include/frnetlib/HttpRequest.h @@ -41,7 +41,7 @@ namespace fr * * @param header_end_pos The position in 'body' of the end of the header */ - void parse_header(int32_t header_end_pos); + bool parse_header(int32_t header_end_pos); /*! * Parses the POST data from the body diff --git a/include/frnetlib/HttpResponse.h b/include/frnetlib/HttpResponse.h index d2d8383..b587b9d 100644 --- a/include/frnetlib/HttpResponse.h +++ b/include/frnetlib/HttpResponse.h @@ -41,7 +41,7 @@ namespace fr * * @param header_end_pos The position in 'body' of the end of the header */ - void parse_header(int32_t header_end_pos); + bool parse_header(int32_t header_end_pos); //State bool header_ended; diff --git a/include/frnetlib/Listener.h b/include/frnetlib/Listener.h index e17bd53..a47a020 100644 --- a/include/frnetlib/Listener.h +++ b/include/frnetlib/Listener.h @@ -39,12 +39,19 @@ namespace fr * Calls the shutdown syscall on the socket. * So you can receive data but not send. * + * Note: THIS DOES NOT CLOSE THE SOCKET. SHUTDOWN AND CLOSE ARE TWO DIFFERENT THINGS. + * * This can be called on a blocking socket to force * it to immediately return (you might want to do this if * you're exiting and need the blocking socket to return). */ virtual void shutdown()=0; + /*! + * Closes the socket + */ + virtual void close_socket()=0; + /*! * Gets the socket descriptor. * diff --git a/include/frnetlib/SSLListener.h b/include/frnetlib/SSLListener.h index a05ac12..ffef04f 100644 --- a/include/frnetlib/SSLListener.h +++ b/include/frnetlib/SSLListener.h @@ -44,6 +44,11 @@ namespace fr */ virtual Socket::Status accept(Socket &client) override; + /*! + * Closes the socket + */ + virtual void close_socket() override; + /*! * Calls the shutdown syscall on the socket. * So you can receive data but not send. diff --git a/include/frnetlib/TcpListener.h b/include/frnetlib/TcpListener.h index ece56ac..dc48393 100644 --- a/include/frnetlib/TcpListener.h +++ b/include/frnetlib/TcpListener.h @@ -16,7 +16,7 @@ namespace fr class TcpListener : public Listener { public: - TcpListener() = default; + TcpListener(); virtual ~TcpListener() override; TcpListener(TcpListener &&o) = default; @@ -60,6 +60,11 @@ public: */ virtual void set_socket_descriptor(int32_t descriptor) override; + /*! + * Closes the socket + */ + virtual void close_socket() override; + private: int32_t socket_descriptor; diff --git a/src/HttpRequest.cpp b/src/HttpRequest.cpp index 41e9508..6187f40 100644 --- a/src/HttpRequest.cpp +++ b/src/HttpRequest.cpp @@ -33,7 +33,8 @@ namespace fr } else { - parse_header(header_end); + if(!parse_header(header_end)) + return false; body.clear(); } content_length += 2; //The empty line between header and data @@ -52,29 +53,38 @@ namespace fr return true; } - void HttpRequest::parse_header(int32_t header_end_pos) + bool HttpRequest::parse_header(int32_t header_end_pos) { - //Split the header into lines - size_t line = 0; - std::vector header_lines = split_string(body.substr(0, header_end_pos)); - if(header_lines.empty()) - return; - - //Parse request type & uri - parse_header_type(header_lines[line]); - parse_header_uri(header_lines[line]); - line++; - - //Read in headers - for(; line < header_lines.size(); line++) + try { - parse_header_line(header_lines[line]); - } + //Split the header into lines + size_t line = 0; + std::vector header_lines = split_string(body.substr(0, header_end_pos)); + if(header_lines.empty()) + return false; + + //Parse request type & uri + parse_header_type(header_lines[line]); + parse_header_uri(header_lines[line]); + line++; + + //Read in headers + for(; line < header_lines.size(); line++) + { + parse_header_line(header_lines[line]); + } + + //Store content length value if it exists + auto length_header_iter = header_data.find("content-length"); + if(length_header_iter != header_data.end()) + content_length = (size_t)std::stoull(length_header_iter->second); + } + catch(const std::exception &e) + { + return false; + } + return true; - //Store content length value if it exists - auto length_header_iter = header_data.find("content-length"); - if(length_header_iter != header_data.end()) - content_length = (size_t)std::stoull(length_header_iter->second); } std::string HttpRequest::construct(const std::string &host) const diff --git a/src/HttpResponse.cpp b/src/HttpResponse.cpp index 1bb9d3c..ace85d3 100644 --- a/src/HttpResponse.cpp +++ b/src/HttpResponse.cpp @@ -62,24 +62,32 @@ namespace fr return response; } - void HttpResponse::parse_header(int32_t header_end_pos) + bool HttpResponse::parse_header(int32_t header_end_pos) { - //Split the header into lines - size_t line = 0; - std::vector header_lines = split_string(body.substr(0, header_end_pos)); - if(header_lines.empty()) - return; - line++; - - //Read in headers - for(; line < header_lines.size(); line++) + try { - parse_header_line(header_lines[line]); - } + //Split the header into lines + size_t line = 0; + std::vector header_lines = split_string(body.substr(0, header_end_pos)); + if(header_lines.empty()) + return false; + line++; - //Store content length value if it exists - auto length_header_iter = header_data.find("content-length"); - if(length_header_iter != header_data.end()) - content_length = std::stoull(length_header_iter->second); + //Read in headers + for(; line < header_lines.size(); line++) + { + parse_header_line(header_lines[line]); + } + + //Store content length value if it exists + auto length_header_iter = header_data.find("content-length"); + if(length_header_iter != header_data.end()) + content_length = std::stoull(length_header_iter->second); + } + catch(const std::exception &e) + { + return false; + } + return true; } } \ No newline at end of file diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index cb6db1e..2a66f94 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -15,7 +15,7 @@ namespace fr : ssl_context(ssl_context_) { //Initialise SSL objects required - mbedtls_net_init(&listen_fd); + listen_fd.fd = -1; mbedtls_ssl_config_init(&conf); mbedtls_x509_crt_init(&srvcert); mbedtls_pk_init(&pkey); @@ -64,7 +64,7 @@ namespace fr SSLListener::~SSLListener() { - mbedtls_net_free(&listen_fd); + close_socket(); mbedtls_x509_crt_free(&srvcert); mbedtls_pk_free(&pkey); mbedtls_ssl_config_free(&conf); @@ -73,6 +73,8 @@ namespace fr Socket::Status fr::SSLListener::listen(const std::string &port) { //This is a hack. mbedtls doesn't support specifying the address family. + close_socket(); + mbedtls_net_init(&listen_fd); fr::TcpListener tcp_listen; tcp_listen.set_inet_version(ai_family); if(tcp_listen.listen(port) != fr::Socket::Success) @@ -142,5 +144,14 @@ namespace fr listen_fd.fd = descriptor; } + void SSLListener::close_socket() + { + if(listen_fd.fd != -1) + { + mbedtls_net_free(&listen_fd); + listen_fd.fd = -1; + } + } + } #endif \ No newline at end of file diff --git a/src/TcpListener.cpp b/src/TcpListener.cpp index 0e5dc78..1aafdc4 100644 --- a/src/TcpListener.cpp +++ b/src/TcpListener.cpp @@ -9,13 +9,16 @@ namespace fr const int yes = 1; const int no = 0; + + TcpListener::TcpListener() + : socket_descriptor(-1) + { + + } + TcpListener::~TcpListener() { - if(socket_descriptor > -1) - { - closesocket(socket_descriptor); - socket_descriptor = -1; - } + close_socket(); } Socket::Status TcpListener::listen(const std::string &port) @@ -122,4 +125,13 @@ namespace fr { socket_descriptor = descriptor; } + + void TcpListener::close_socket() + { + if(socket_descriptor > -1) + { + closesocket(socket_descriptor); + socket_descriptor = -1; + } + } } \ No newline at end of file diff --git a/tests/TcpListenerTest.cpp b/tests/TcpListenerTest.cpp new file mode 100644 index 0000000..b16153b --- /dev/null +++ b/tests/TcpListenerTest.cpp @@ -0,0 +1,47 @@ +// +// Created by fred on 05/06/17. +// + +#include +#include +#include + +TEST(TcpListenerTest, listner_listen) +{ + fr::TcpListener listener; + ASSERT_EQ(listener.get_socket_descriptor(), -1); + fr::Socket::Status ret = listener.listen("9090"); + ASSERT_EQ(ret, fr::Socket::Success); + listener.close_socket(); + ASSERT_EQ(listener.get_socket_descriptor(), -1); +} + + +TEST(TcpListenerTest, listener_accept) +{ + fr::TcpListener listener; + listener.set_inet_version(fr::Socket::IP::v4); + if(listener.listen("9095") != fr::Socket::Success) + FAIL(); + + auto client_thread = []() + { + fr::TcpSocket socket; + socket.set_inet_version(fr::Socket::IP::v4); + auto ret = socket.connect("127.0.0.1", "9095"); + ASSERT_EQ(ret, fr::Socket::Success); + }; + + std::thread t1(client_thread); + fr::TcpSocket socket; + auto ret = listener.accept(socket); + ASSERT_EQ(ret, fr::Socket::Success); + t1.join(); +} + +TEST(TcpListenerTest, set_descriptor) +{ + fr::TcpListener listener; + listener.set_socket_descriptor(-20); + ASSERT_EQ(listener.get_socket_descriptor(), -20); +}