diff --git a/include/frnetlib/Http.h b/include/frnetlib/Http.h index 51d9968..9909017 100644 --- a/include/frnetlib/Http.h +++ b/include/frnetlib/Http.h @@ -17,13 +17,14 @@ namespace fr public: enum RequestType { - Unknown = 0, - Get = 1, - Post = 2, - Put = 3, - Delete = 4, - Patch = 5, - RequestTypeCount = 6, //Keep me at the end and updated + Get = 0, + Post = 1, + Put = 2, + Delete = 3, + Patch = 4, + RequestTypeCount = 5, //Keep me at the end of valid HTTP request types, and updated + Unknown = 6, + Partial = 7, }; enum RequestStatus { @@ -247,14 +248,6 @@ namespace fr */ const static std::string &get_mimetype(const std::string &filename); - protected: - /*! - * Splits a string by new line. Ignores escaped \n's - * - * @return The split string - */ - std::vector split_string(const std::string &str); - /*! * Converts a 'RequestType' enum value to * a printable string. @@ -262,18 +255,23 @@ namespace fr * @param type The RequestType to convert * @return The printable version of the enum value */ - std::string request_type_to_string(RequestType type) const; + static std::string request_type_to_string(RequestType type); /*! - * Converts hexadecimal to an integer. + * Converts a string value into a 'RequestType' enum value. * - * @param hex The hex value to convert - * @return The decimal equivilent of the hexadecimal value. + * @param str The string to convert + * @return The converted RequestType. Unknown on failure. Or Partial if str is part of a request type. */ - static inline int dectohex(const std::string &hex) - { - return (int)strtol(&hex[0], nullptr, 16); - } + static RequestType string_to_request_type(const std::string &str) ; + + protected: + /*! + * Splits a string by new line. Ignores escaped \n's + * + * @return The split string + */ + static std::vector split_string(const std::string &str); /*! * Converts a parameter list to a vector pair. @@ -283,7 +281,7 @@ namespace fr * @param str The string to parse * @return The vector containing the results pairs */ - std::vector> parse_argument_list(const std::string &str); + static std::vector> parse_argument_list(const std::string &str); /*! * Parses a header line in a HTTP request/response diff --git a/include/frnetlib/HttpRequest.h b/include/frnetlib/HttpRequest.h index 36774a9..5343601 100644 --- a/include/frnetlib/HttpRequest.h +++ b/include/frnetlib/HttpRequest.h @@ -57,8 +57,9 @@ namespace fr * Parses the header type (GET/POST) from the given string. * * @param str The first header line + * @return The parsed request type */ - void parse_header_type(const std::string &str); + Http::RequestType parse_header_type(const std::string &str); /*! * Parses the header URI diff --git a/src/Http.cpp b/src/Http.cpp index fef1b90..1fdac96 100644 --- a/src/Http.cpp +++ b/src/Http.cpp @@ -5,23 +5,17 @@ #include #include #include +#include #include "frnetlib/Http.h" namespace fr { - const static std::string request_type_strings[Http::RequestType::RequestTypeCount] = {"UNKNOWN", - "GET", - "POST", - "PUT", - "DELETE", - "PATCH"}; - Http::Http() : request_type(Unknown), uri("/"), status(Ok) { - static_assert(Http::RequestType::RequestTypeCount == 6, "Please update request_type_strings"); + } Http::RequestType Http::get_type() const @@ -118,13 +112,42 @@ namespace fr uri = '/' + str; } - std::string Http::request_type_to_string(RequestType type) const + std::string Http::request_type_to_string(RequestType type) { + static_assert(RequestType::RequestTypeCount == 5, "Update request_type_to_string"); + const static std::string request_type_strings[RequestType::RequestTypeCount] = {"GET", + "POST", + "PUT", + "DELETE", + "PATCH"}; + if(type >= RequestType::RequestTypeCount) - return request_type_strings[0]; + return "UNKNOWN"; return request_type_strings[type]; } + Http::RequestType Http::string_to_request_type(const std::string &str) + { + //Find the request type + static_assert(Http::RequestTypeCount == 5, "Update parse_header_type()"); + + RequestType type = Http::Unknown; + for(size_t a = 0; a < Http::RequestTypeCount; ++a) + { + std::string type_string = request_type_to_string(static_cast(a)); + int cmp_ret = str.compare(0, type_string.size(), type_string); + if(cmp_ret == 0) + return static_cast(a); + if(str.size() < type_string.size() && cmp_ret < 0) + type = Http::Partial; + if(type != Http::Partial && str.size() < type_string.size() && cmp_ret > 0) + type = Http::Unknown; + } + + return type; + } + + void Http::set_type(Http::RequestType type) { request_type = type; @@ -137,18 +160,31 @@ namespace fr std::string Http::url_encode(const std::string &str) { - std::stringstream encoded; - encoded << std::hex; - for(const auto &c : str) + static const char hex_lookup[]= "0123456789ABCDEF"; + std::string out; + for(char c : str) { - if(isalnum(c) || c == '-' || c == '_' || c == '.' || c == '~') - encoded << c; - else if(c == ' ') - encoded << '+'; + if ((c >= '0' && c <= '9') || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + c == '-' || c == '_' || c == '.' || c == '!' || c == '~' || + c == '*' || c == '\'' || c == '(' || c == ')') + { + out += c; + } + else if (c == ' ') + { + out += '+'; + } else - encoded << "%" << std::uppercase << (int)c << std::nouppercase; + { + out.push_back('%'); + out.push_back(hex_lookup[(c&0xF0)>>4]); + out.push_back(hex_lookup[(c&0x0F)]); + } } - return encoded.str(); + + return out; } std::string Http::url_decode(const std::string &str) @@ -158,8 +194,14 @@ namespace fr { if(str[a] == '%' && a < str.size() - 1) { - result += (char)dectohex(str.substr(a + 1, 2)); - a += 2; + int ch1 = str[++a] - 48; + int ch2 = str[++a] - 48; + if(ch1 > 9) ch1 -= 7; + if(ch2 > 9) ch2 -= 7; + uint8_t ret = 0; + ret |= ch1 << 4; + ret |= ch2; + result.push_back(ret); } else if(str[a] == '+') { diff --git a/src/HttpRequest.cpp b/src/HttpRequest.cpp index a3b21e8..35a6e4e 100644 --- a/src/HttpRequest.cpp +++ b/src/HttpRequest.cpp @@ -22,6 +22,10 @@ namespace fr //Ensure that the whole header has been parsed first if(!header_ended) { + //Verify that it's a valid HTTP header so far + if(!body.empty() && Http::string_to_request_type(body) == fr::Http::Unknown) + return fr::Socket::ParseError; + //Check to see if this request data contains the end of the header uint16_t header_end_size = 4; auto header_end = body.find("\r\n\r\n"); @@ -78,7 +82,9 @@ namespace fr return true; //Parse request type & uri - parse_header_type(header_lines[line]); + request_type = parse_header_type(header_lines[line]); + if(request_type > Http::RequestTypeCount) + return false; parse_header_uri(header_lines[line]); line++; @@ -177,29 +183,15 @@ namespace fr } } - void HttpRequest::parse_header_type(const std::string &str) + Http::RequestType HttpRequest::parse_header_type(const std::string &str) { //Find the request type auto type_end = str.find(' '); if(type_end != std::string::npos) { - //Check what it is - if(str.compare(0, type_end, "GET") == 0) - request_type = fr::Http::Get; - else if(str.compare(0, type_end, "POST") == 0) - request_type = fr::Http::Post; - else if(str.compare(0, type_end, "PUT") == 0) - request_type = fr::Http::Put; - else if(str.compare(0, type_end, "DELETE") == 0) - request_type = fr::Http::Delete; - else if(str.compare(0, type_end, "PATCH") == 0) - request_type = fr::Http::Patch; - else - request_type = fr::Http::Unknown; - - return; + return string_to_request_type(str.substr(0, type_end)); } - throw std::invalid_argument("No known request type found in: " + str); + return Http::Unknown; } void HttpRequest::parse_header_uri(const std::string &str) diff --git a/src/HttpResponse.cpp b/src/HttpResponse.cpp index 3b01cfb..70d3af2 100644 --- a/src/HttpResponse.cpp +++ b/src/HttpResponse.cpp @@ -14,6 +14,10 @@ namespace fr //Ensure that the whole header has been parsed first if(!header_ended) { + //Verify that it's a valid HTTP response if there's enough data + if(body.size() >= 4 && body.compare(0, 4, "HTTP") != 0) + return fr::Socket::ParseError; + //Check to see if this request data contains the end of the header uint16_t header_end_size = 4; auto header_end = body.find("\r\n\r\n"); diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index d71e830..7fc6814 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -91,7 +91,7 @@ namespace fr //Initialise mbedtls int error = 0; std::unique_ptr ssl(new mbedtls_ssl_context); - auto client_fd = std::make_unique(); + std::unique_ptr client_fd(new mbedtls_net_context); mbedtls_ssl_init(ssl.get()); mbedtls_net_init(client_fd.get()); @@ -108,7 +108,6 @@ namespace fr size_t ip_len = 0; if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), client_ip, sizeof(client_ip), &ip_len)) != 0) { - std::cout << "Accept error: " << error << std::endl; free_contexts(); return Socket::Error; } @@ -120,7 +119,6 @@ namespace fr { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { - std::cout << "Handshake error: " << error << std::endl; free_contexts(); return Socket::Status::HandshakeFailed; } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 9599a6c..79ab9dd 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -84,6 +84,7 @@ namespace fr { //Initialise mbedtls stuff ssl = std::make_unique(); + ssl_socket_descriptor = std::make_unique(); mbedtls_ssl_init(ssl.get()); mbedtls_net_init(ssl_socket_descriptor.get()); @@ -94,7 +95,6 @@ namespace fr auto ret = socket.connect(address, port, timeout); if(ret != fr::Socket::Success) return ret; - ssl_socket_descriptor = std::make_unique(); ssl_socket_descriptor->fd = socket.get_socket_descriptor(); remote_address = socket.get_remote_address(); socket.set_descriptor(nullptr); @@ -121,7 +121,7 @@ namespace fr return Socket::Status::Error; } - mbedtls_ssl_set_bio(ssl.get(), &ssl_socket_descriptor, mbedtls_net_send, mbedtls_net_recv, nullptr); + mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); //Do SSL handshake while((error = mbedtls_ssl_handshake(ssl.get())) != 0) @@ -157,7 +157,8 @@ namespace fr void SSLSocket::set_descriptor(void *descriptor) { ssl_socket_descriptor.reset(static_cast(descriptor)); - reconfigure_socket(); + if(descriptor) + reconfigure_socket(); } void SSLSocket::verify_certificates(bool should_verify_) diff --git a/src/TcpListener.cpp b/src/TcpListener.cpp index a20fc8a..249e0aa 100644 --- a/src/TcpListener.cpp +++ b/src/TcpListener.cpp @@ -90,7 +90,7 @@ namespace fr //Prepare to wait for the client sockaddr_storage client_addr{}; - int client_descriptor; + int32_t client_descriptor; char client_printable_addr[INET6_ADDRSTRLEN]; //Accept one diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 98f9517..5f0d290 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -80,6 +80,11 @@ namespace fr void TcpSocket::set_descriptor(void *descriptor) { + if(!descriptor) + { + socket_descriptor = -1; + return; + } socket_descriptor = *static_cast(descriptor); reconfigure_socket(); } diff --git a/tests/HttpRequestTest.cpp b/tests/HttpRequestTest.cpp index bcb0dc3..45f7b0b 100644 --- a/tests/HttpRequestTest.cpp +++ b/tests/HttpRequestTest.cpp @@ -82,29 +82,34 @@ TEST(HttpRequestTest, request_type_parse) const std::string delete_request = "DELETE / HTTP/1.1\r\n\r\n"; const std::string patch_request = "PATCH / HTTP/1.1\r\n\r\n"; const std::string invalid_request = "INVALID / HTTP/1.1\r\n\r\n"; + const std::string invalid_request2 = "PU / HTTP/1.1\r\n\r\n"; fr::HttpRequest request; - request.parse(get_request.c_str(), get_request.size()); + ASSERT_EQ(request.parse(get_request.c_str(), get_request.size()), fr::Socket::Success); ASSERT_EQ(request.get_type(), fr::Http::Get); request = {}; - request.parse(post_request.c_str(), post_request.size()); + ASSERT_EQ(request.parse(post_request.c_str(), post_request.size()), fr::Socket::Success); ASSERT_EQ(request.get_type(), fr::Http::Post); request = {}; - request.parse(put_request.c_str(), put_request.size()); + ASSERT_EQ(request.parse(put_request.c_str(), put_request.size()), fr::Socket::Success); ASSERT_EQ(request.get_type(), fr::Http::Put); request = {}; - request.parse(delete_request.c_str(), delete_request.size()); + ASSERT_EQ(request.parse(delete_request.c_str(), delete_request.size()), fr::Socket::Success); ASSERT_EQ(request.get_type(), fr::Http::Delete); request = {}; - request.parse(patch_request.c_str(), patch_request.size()); + ASSERT_EQ(request.parse(patch_request.c_str(), patch_request.size()), fr::Socket::Success); ASSERT_EQ(request.get_type(), fr::Http::Patch); request = {}; - request.parse(invalid_request.c_str(), invalid_request.size()); + ASSERT_EQ(request.parse(invalid_request.c_str(), invalid_request.size()), fr::Socket::ParseError); + ASSERT_EQ(request.get_type(), fr::Http::Unknown); + request = {}; + + ASSERT_EQ(request.parse(invalid_request2.c_str(), invalid_request2.size()), fr::Socket::ParseError); ASSERT_EQ(request.get_type(), fr::Http::Unknown); request = {}; } diff --git a/tests/HttpResponseTest.cpp b/tests/HttpResponseTest.cpp index a262d39..9e67213 100644 --- a/tests/HttpResponseTest.cpp +++ b/tests/HttpResponseTest.cpp @@ -92,6 +92,7 @@ TEST(HttpResponseTest, header_length_test) //Try data with no header end first std::string buff(MAX_HTTP_HEADER_SIZE + 1, '\0'); fr::HttpResponse response; + buff.insert(0, "HTTP"); ASSERT_EQ(response.parse(buff.c_str(), buff.size()), fr::Socket::HttpHeaderTooBig); response = {}; diff --git a/tests/HttpTest.cpp b/tests/HttpTest.cpp new file mode 100644 index 0000000..b3be53e --- /dev/null +++ b/tests/HttpTest.cpp @@ -0,0 +1,60 @@ +// +// Created by fred.nicolson on 01/03/18. +// + +#include +#include + +TEST(HttpTest, test_request_type_to_string) +{ + for(size_t a = 0; a < fr::Http::RequestTypeCount; ++a) + { + ASSERT_EQ(fr::Http::string_to_request_type(fr::Http::request_type_to_string((fr::Http::RequestType)a)), a); + } + + ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::Partial), "UNKNOWN"); + ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::RequestTypeCount), "UNKNOWN"); + ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::Unknown), "UNKNOWN"); +} + +TEST(HttpTest, test_string_to_request_type) +{ + std::vector> strings = { + {fr::Http::Get, "GET"}, + {fr::Http::Put, "PUT"}, + {fr::Http::Delete, "DELETE"}, + {fr::Http::Patch, "PATCH"}, + {fr::Http::Patch, "PATCHid-=wa"}, + {fr::Http::Partial, "PA"}, + {fr::Http::Partial, "PU"}, + {fr::Http::Partial, "DELET"}, + {fr::Http::Unknown, "DELETa"}, + {fr::Http::Unknown, "U"}, + {fr::Http::Unknown, "dwaouidhwi"}, + {fr::Http::Unknown, "get"}, + }; + + for(auto &str : strings) + { + ASSERT_EQ(fr::Http::string_to_request_type(str.second), str.first); + } +} + +TEST(HttpTest, test_url_encode) +{ + std::string source = "1\"!£FEW$\"931-90%%+-&*0(du%a90dj09=_da.A~"; + ASSERT_EQ(fr::Http::url_encode(source), "1%22!%C2%A3FEW%24%22931-90%25%25%2B-%26*0(du%25a90dj09%3D_da.A~"); +} + +TEST(HttpTest, test_url_decode) +{ + std::string source = "1%22!%C2%A3FEW%24%22931-90%25%25%2B-%26*0(du%25a90dj09%3D_da.A~"; + ASSERT_EQ(fr::Http::url_decode(source), "1\"!£FEW$\"931-90%%+-&*0(du%a90dj09=_da.A~"); +} + +TEST(HttpTest, test_get_mimetype) +{ + ASSERT_EQ(fr::Http::get_mimetype(".html"), "text/html"); + ASSERT_EQ(fr::Http::get_mimetype("my_file.html"), "text/html"); + ASSERT_EQ(fr::Http::get_mimetype("file.some_random_type"), "application/octet-stream"); +} \ No newline at end of file