Bug fixes. HTTP parsing improvements. More tests.

Fixed TcpSocket::set_descriptor(nullptr) causing an invalid read from address 0x0.

Improved HTTP response/request parsing so that they report a parse failure if the first few bytes of the HTTP request don't match the expected format rather than continuing to look for an end of header.

Fixed broken fr::Http::url_encode() implementation.

Optimised fr::Http::url_decode() implementation.

Added fr::Http unit tests.
This commit is contained in:
Fred Nicolson 2018-03-01 15:51:57 +00:00
parent 62d8b7ba63
commit 103e0faaae
12 changed files with 184 additions and 77 deletions

View File

@ -17,13 +17,14 @@ namespace fr
public:
enum RequestType
{
Unknown = 0,
Get = 1,
Post = 2,
Put = 3,
Delete = 4,
Patch = 5,
RequestTypeCount = 6, //Keep me at the end and updated
Get = 0,
Post = 1,
Put = 2,
Delete = 3,
Patch = 4,
RequestTypeCount = 5, //Keep me at the end of valid HTTP request types, and updated
Unknown = 6,
Partial = 7,
};
enum RequestStatus
{
@ -247,14 +248,6 @@ namespace fr
*/
const static std::string &get_mimetype(const std::string &filename);
protected:
/*!
* Splits a string by new line. Ignores escaped \n's
*
* @return The split string
*/
std::vector<std::string> split_string(const std::string &str);
/*!
* Converts a 'RequestType' enum value to
* a printable string.
@ -262,18 +255,23 @@ namespace fr
* @param type The RequestType to convert
* @return The printable version of the enum value
*/
std::string request_type_to_string(RequestType type) const;
static std::string request_type_to_string(RequestType type);
/*!
* Converts hexadecimal to an integer.
* Converts a string value into a 'RequestType' enum value.
*
* @param hex The hex value to convert
* @return The decimal equivilent of the hexadecimal value.
* @param str The string to convert
* @return The converted RequestType. Unknown on failure. Or Partial if str is part of a request type.
*/
static inline int dectohex(const std::string &hex)
{
return (int)strtol(&hex[0], nullptr, 16);
}
static RequestType string_to_request_type(const std::string &str) ;
protected:
/*!
* Splits a string by new line. Ignores escaped \n's
*
* @return The split string
*/
static std::vector<std::string> split_string(const std::string &str);
/*!
* Converts a parameter list to a vector pair.
@ -283,7 +281,7 @@ namespace fr
* @param str The string to parse
* @return The vector containing the results pairs
*/
std::vector<std::pair<std::string, std::string>> parse_argument_list(const std::string &str);
static std::vector<std::pair<std::string, std::string>> parse_argument_list(const std::string &str);
/*!
* Parses a header line in a HTTP request/response

View File

@ -57,8 +57,9 @@ namespace fr
* Parses the header type (GET/POST) from the given string.
*
* @param str The first header line
* @return The parsed request type
*/
void parse_header_type(const std::string &str);
Http::RequestType parse_header_type(const std::string &str);
/*!
* Parses the header URI

View File

@ -5,23 +5,17 @@
#include <iostream>
#include <sstream>
#include <algorithm>
#include <iomanip>
#include "frnetlib/Http.h"
namespace fr
{
const static std::string request_type_strings[Http::RequestType::RequestTypeCount] = {"UNKNOWN",
"GET",
"POST",
"PUT",
"DELETE",
"PATCH"};
Http::Http()
: request_type(Unknown),
uri("/"),
status(Ok)
{
static_assert(Http::RequestType::RequestTypeCount == 6, "Please update request_type_strings");
}
Http::RequestType Http::get_type() const
@ -118,13 +112,42 @@ namespace fr
uri = '/' + str;
}
std::string Http::request_type_to_string(RequestType type) const
std::string Http::request_type_to_string(RequestType type)
{
static_assert(RequestType::RequestTypeCount == 5, "Update request_type_to_string");
const static std::string request_type_strings[RequestType::RequestTypeCount] = {"GET",
"POST",
"PUT",
"DELETE",
"PATCH"};
if(type >= RequestType::RequestTypeCount)
return request_type_strings[0];
return "UNKNOWN";
return request_type_strings[type];
}
Http::RequestType Http::string_to_request_type(const std::string &str)
{
//Find the request type
static_assert(Http::RequestTypeCount == 5, "Update parse_header_type()");
RequestType type = Http::Unknown;
for(size_t a = 0; a < Http::RequestTypeCount; ++a)
{
std::string type_string = request_type_to_string(static_cast<RequestType>(a));
int cmp_ret = str.compare(0, type_string.size(), type_string);
if(cmp_ret == 0)
return static_cast<RequestType>(a);
if(str.size() < type_string.size() && cmp_ret < 0)
type = Http::Partial;
if(type != Http::Partial && str.size() < type_string.size() && cmp_ret > 0)
type = Http::Unknown;
}
return type;
}
void Http::set_type(Http::RequestType type)
{
request_type = type;
@ -137,18 +160,31 @@ namespace fr
std::string Http::url_encode(const std::string &str)
{
std::stringstream encoded;
encoded << std::hex;
for(const auto &c : str)
static const char hex_lookup[]= "0123456789ABCDEF";
std::string out;
for(char c : str)
{
if(isalnum(c) || c == '-' || c == '_' || c == '.' || c == '~')
encoded << c;
else if(c == ' ')
encoded << '+';
if ((c >= '0' && c <= '9') ||
(c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
c == '-' || c == '_' || c == '.' || c == '!' || c == '~' ||
c == '*' || c == '\'' || c == '(' || c == ')')
{
out += c;
}
else if (c == ' ')
{
out += '+';
}
else
encoded << "%" << std::uppercase << (int)c << std::nouppercase;
{
out.push_back('%');
out.push_back(hex_lookup[(c&0xF0)>>4]);
out.push_back(hex_lookup[(c&0x0F)]);
}
}
return encoded.str();
return out;
}
std::string Http::url_decode(const std::string &str)
@ -158,8 +194,14 @@ namespace fr
{
if(str[a] == '%' && a < str.size() - 1)
{
result += (char)dectohex(str.substr(a + 1, 2));
a += 2;
int ch1 = str[++a] - 48;
int ch2 = str[++a] - 48;
if(ch1 > 9) ch1 -= 7;
if(ch2 > 9) ch2 -= 7;
uint8_t ret = 0;
ret |= ch1 << 4;
ret |= ch2;
result.push_back(ret);
}
else if(str[a] == '+')
{

View File

@ -22,6 +22,10 @@ namespace fr
//Ensure that the whole header has been parsed first
if(!header_ended)
{
//Verify that it's a valid HTTP header so far
if(!body.empty() && Http::string_to_request_type(body) == fr::Http::Unknown)
return fr::Socket::ParseError;
//Check to see if this request data contains the end of the header
uint16_t header_end_size = 4;
auto header_end = body.find("\r\n\r\n");
@ -78,7 +82,9 @@ namespace fr
return true;
//Parse request type & uri
parse_header_type(header_lines[line]);
request_type = parse_header_type(header_lines[line]);
if(request_type > Http::RequestTypeCount)
return false;
parse_header_uri(header_lines[line]);
line++;
@ -177,29 +183,15 @@ namespace fr
}
}
void HttpRequest::parse_header_type(const std::string &str)
Http::RequestType HttpRequest::parse_header_type(const std::string &str)
{
//Find the request type
auto type_end = str.find(' ');
if(type_end != std::string::npos)
{
//Check what it is
if(str.compare(0, type_end, "GET") == 0)
request_type = fr::Http::Get;
else if(str.compare(0, type_end, "POST") == 0)
request_type = fr::Http::Post;
else if(str.compare(0, type_end, "PUT") == 0)
request_type = fr::Http::Put;
else if(str.compare(0, type_end, "DELETE") == 0)
request_type = fr::Http::Delete;
else if(str.compare(0, type_end, "PATCH") == 0)
request_type = fr::Http::Patch;
else
request_type = fr::Http::Unknown;
return;
return string_to_request_type(str.substr(0, type_end));
}
throw std::invalid_argument("No known request type found in: " + str);
return Http::Unknown;
}
void HttpRequest::parse_header_uri(const std::string &str)

View File

@ -14,6 +14,10 @@ namespace fr
//Ensure that the whole header has been parsed first
if(!header_ended)
{
//Verify that it's a valid HTTP response if there's enough data
if(body.size() >= 4 && body.compare(0, 4, "HTTP") != 0)
return fr::Socket::ParseError;
//Check to see if this request data contains the end of the header
uint16_t header_end_size = 4;
auto header_end = body.find("\r\n\r\n");

View File

@ -91,7 +91,7 @@ namespace fr
//Initialise mbedtls
int error = 0;
std::unique_ptr<mbedtls_ssl_context> ssl(new mbedtls_ssl_context);
auto client_fd = std::make_unique<mbedtls_net_context>();
std::unique_ptr<mbedtls_net_context> client_fd(new mbedtls_net_context);
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(client_fd.get());
@ -108,7 +108,6 @@ namespace fr
size_t ip_len = 0;
if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), client_ip, sizeof(client_ip), &ip_len)) != 0)
{
std::cout << "Accept error: " << error << std::endl;
free_contexts();
return Socket::Error;
}
@ -120,7 +119,6 @@ namespace fr
{
if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE)
{
std::cout << "Handshake error: " << error << std::endl;
free_contexts();
return Socket::Status::HandshakeFailed;
}

View File

@ -84,6 +84,7 @@ namespace fr
{
//Initialise mbedtls stuff
ssl = std::make_unique<mbedtls_ssl_context>();
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
mbedtls_ssl_init(ssl.get());
mbedtls_net_init(ssl_socket_descriptor.get());
@ -94,7 +95,6 @@ namespace fr
auto ret = socket.connect(address, port, timeout);
if(ret != fr::Socket::Success)
return ret;
ssl_socket_descriptor = std::make_unique<mbedtls_net_context>();
ssl_socket_descriptor->fd = socket.get_socket_descriptor();
remote_address = socket.get_remote_address();
socket.set_descriptor(nullptr);
@ -121,7 +121,7 @@ namespace fr
return Socket::Status::Error;
}
mbedtls_ssl_set_bio(ssl.get(), &ssl_socket_descriptor, mbedtls_net_send, mbedtls_net_recv, nullptr);
mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr);
//Do SSL handshake
while((error = mbedtls_ssl_handshake(ssl.get())) != 0)
@ -157,7 +157,8 @@ namespace fr
void SSLSocket::set_descriptor(void *descriptor)
{
ssl_socket_descriptor.reset(static_cast<mbedtls_net_context*>(descriptor));
reconfigure_socket();
if(descriptor)
reconfigure_socket();
}
void SSLSocket::verify_certificates(bool should_verify_)

View File

@ -90,7 +90,7 @@ namespace fr
//Prepare to wait for the client
sockaddr_storage client_addr{};
int client_descriptor;
int32_t client_descriptor;
char client_printable_addr[INET6_ADDRSTRLEN];
//Accept one

View File

@ -80,6 +80,11 @@ namespace fr
void TcpSocket::set_descriptor(void *descriptor)
{
if(!descriptor)
{
socket_descriptor = -1;
return;
}
socket_descriptor = *static_cast<int32_t*>(descriptor);
reconfigure_socket();
}

View File

@ -82,29 +82,34 @@ TEST(HttpRequestTest, request_type_parse)
const std::string delete_request = "DELETE / HTTP/1.1\r\n\r\n";
const std::string patch_request = "PATCH / HTTP/1.1\r\n\r\n";
const std::string invalid_request = "INVALID / HTTP/1.1\r\n\r\n";
const std::string invalid_request2 = "PU / HTTP/1.1\r\n\r\n";
fr::HttpRequest request;
request.parse(get_request.c_str(), get_request.size());
ASSERT_EQ(request.parse(get_request.c_str(), get_request.size()), fr::Socket::Success);
ASSERT_EQ(request.get_type(), fr::Http::Get);
request = {};
request.parse(post_request.c_str(), post_request.size());
ASSERT_EQ(request.parse(post_request.c_str(), post_request.size()), fr::Socket::Success);
ASSERT_EQ(request.get_type(), fr::Http::Post);
request = {};
request.parse(put_request.c_str(), put_request.size());
ASSERT_EQ(request.parse(put_request.c_str(), put_request.size()), fr::Socket::Success);
ASSERT_EQ(request.get_type(), fr::Http::Put);
request = {};
request.parse(delete_request.c_str(), delete_request.size());
ASSERT_EQ(request.parse(delete_request.c_str(), delete_request.size()), fr::Socket::Success);
ASSERT_EQ(request.get_type(), fr::Http::Delete);
request = {};
request.parse(patch_request.c_str(), patch_request.size());
ASSERT_EQ(request.parse(patch_request.c_str(), patch_request.size()), fr::Socket::Success);
ASSERT_EQ(request.get_type(), fr::Http::Patch);
request = {};
request.parse(invalid_request.c_str(), invalid_request.size());
ASSERT_EQ(request.parse(invalid_request.c_str(), invalid_request.size()), fr::Socket::ParseError);
ASSERT_EQ(request.get_type(), fr::Http::Unknown);
request = {};
ASSERT_EQ(request.parse(invalid_request2.c_str(), invalid_request2.size()), fr::Socket::ParseError);
ASSERT_EQ(request.get_type(), fr::Http::Unknown);
request = {};
}

View File

@ -92,6 +92,7 @@ TEST(HttpResponseTest, header_length_test)
//Try data with no header end first
std::string buff(MAX_HTTP_HEADER_SIZE + 1, '\0');
fr::HttpResponse response;
buff.insert(0, "HTTP");
ASSERT_EQ(response.parse(buff.c_str(), buff.size()), fr::Socket::HttpHeaderTooBig);
response = {};

60
tests/HttpTest.cpp Normal file
View File

@ -0,0 +1,60 @@
//
// Created by fred.nicolson on 01/03/18.
//
#include <gtest/gtest.h>
#include <frnetlib/HttpResponse.h>
TEST(HttpTest, test_request_type_to_string)
{
for(size_t a = 0; a < fr::Http::RequestTypeCount; ++a)
{
ASSERT_EQ(fr::Http::string_to_request_type(fr::Http::request_type_to_string((fr::Http::RequestType)a)), a);
}
ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::Partial), "UNKNOWN");
ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::RequestTypeCount), "UNKNOWN");
ASSERT_EQ(fr::Http::request_type_to_string(fr::Http::Unknown), "UNKNOWN");
}
TEST(HttpTest, test_string_to_request_type)
{
std::vector<std::pair<fr::Http::RequestType, std::string>> strings = {
{fr::Http::Get, "GET"},
{fr::Http::Put, "PUT"},
{fr::Http::Delete, "DELETE"},
{fr::Http::Patch, "PATCH"},
{fr::Http::Patch, "PATCHid-=wa"},
{fr::Http::Partial, "PA"},
{fr::Http::Partial, "PU"},
{fr::Http::Partial, "DELET"},
{fr::Http::Unknown, "DELETa"},
{fr::Http::Unknown, "U"},
{fr::Http::Unknown, "dwaouidhwi"},
{fr::Http::Unknown, "get"},
};
for(auto &str : strings)
{
ASSERT_EQ(fr::Http::string_to_request_type(str.second), str.first);
}
}
TEST(HttpTest, test_url_encode)
{
std::string source = "1\"!£FEW$\"931-90%%+-&*0(du%a90dj09=_da.A~";
ASSERT_EQ(fr::Http::url_encode(source), "1%22!%C2%A3FEW%24%22931-90%25%25%2B-%26*0(du%25a90dj09%3D_da.A~");
}
TEST(HttpTest, test_url_decode)
{
std::string source = "1%22!%C2%A3FEW%24%22931-90%25%25%2B-%26*0(du%25a90dj09%3D_da.A~";
ASSERT_EQ(fr::Http::url_decode(source), "1\"!£FEW$\"931-90%%+-&*0(du%a90dj09=_da.A~");
}
TEST(HttpTest, test_get_mimetype)
{
ASSERT_EQ(fr::Http::get_mimetype(".html"), "text/html");
ASSERT_EQ(fr::Http::get_mimetype("my_file.html"), "text/html");
ASSERT_EQ(fr::Http::get_mimetype("file.some_random_type"), "application/octet-stream");
}