Added experimental WebSocket support.

Added Sha1 hash, and Base64 encode implementations which are optionally compiled if websock support is enabled, to assist in the WebSock handshake.

Added WebSocket to manage the WebSock protocol during connections.

Added WebFrame to allow for sending/receiving data through the WebSock protocol easily.
This commit is contained in:
Unknown 2018-03-01 23:03:05 +00:00
parent 103e0faaae
commit fa843b57c8
11 changed files with 899 additions and 14 deletions

View File

@ -6,16 +6,15 @@ set(FRNETLIB_LINK_LIBRARIES "")
set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules) set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules)
#User options #User options
option(USE_SSL "Use SSL" OFF) option(USE_SSL "Enable SSL support" OFF)
option(BUILD_EXAMPLES "Build frnetlib examples" ON)
option(BUILD_TESTS "Build frnetlib tests" ON)
option(BUILD_WEBSOCK "Enable WebSocket support" ON)
set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.") set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.")
set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes") set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes")
set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes") set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes")
set(LISTEN_QUEUE_SIZE "64" CACHE STRING "The listen queue depth for fr::TcpListener/fr::SSLListener") set(LISTEN_QUEUE_SIZE "64" CACHE STRING "The listen queue depth for fr::TcpListener/fr::SSLListener")
#Enable tests and examples by default
option(BUILD_EXAMPLES "Build frnetlib examples" ON)
option(BUILD_TESTS "Build frnetlib tests" ON)
#Configure defines based on user options #Configure defines based on user options
add_definitions(-DMAX_HTTP_HEADER_SIZE=${MAX_HTTP_HEADER_SIZE}) add_definitions(-DMAX_HTTP_HEADER_SIZE=${MAX_HTTP_HEADER_SIZE})
add_definitions(-DMAX_HTTP_BODY_SIZE=${MAX_HTTP_BODY_SIZE}) add_definitions(-DMAX_HTTP_BODY_SIZE=${MAX_HTTP_BODY_SIZE})
@ -27,6 +26,10 @@ if(USE_SSL)
set(SOURCE_FILES ${SOURCE_FILES} src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h) set(SOURCE_FILES ${SOURCE_FILES} src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h)
endif() endif()
if(BUILD_WEBSOCK)
set(SOURCE_FILES ${SOURCE_FILES} src/WebFrame.cpp include/frnetlib/WebFrame.h src/Sha1.cpp include/frnetlib/Sha1.h src/Base64.cpp include/frnetlib/Base64.h src/Sha1.cpp include/frnetlib/WebSocket.h)
endif()
add_definitions(-DNOMINMAX) add_definitions(-DNOMINMAX)
add_definitions(-Dhtonf) add_definitions(-Dhtonf)
add_definitions(-Dhtonll) add_definitions(-Dhtonll)

25
include/frnetlib/Base64.h Normal file
View File

@ -0,0 +1,25 @@
//
// Created by fred on 01/03/18.
//
#ifndef FRNETLIB_BASE64_H
#define FRNETLIB_BASE64_H
#include <string>
class Base64
{
public:
/*!
* Encodes a string into Base64
*
* @param input The string to encode
* @return The resulting encoded string
*/
static std::string encode(const std::string &input);
//There's no decode function at the moment. Maybe I'll write one eventually. Sorry.
private:
};
#endif //FRNETLIB_BASE64_H

View File

@ -94,7 +94,7 @@ namespace fr
/*! /*!
* Gets the underlying socket descriptor. * Gets the underlying socket descriptor.
* *
* @return The socket's descriptor. * @return The socket's descriptor. -1 indicates no connection.
*/ */
inline int32_t get_socket_descriptor() const override inline int32_t get_socket_descriptor() const override
{ {

43
include/frnetlib/Sha1.h Normal file
View File

@ -0,0 +1,43 @@
//
// Created by fred on 01/03/18.
//
#ifndef FRNETLIB_SHA1_H
#define FRNETLIB_SHA1_H
#include <string>
#include "frnetlib/NetworkEncoding.h"
class Sha1
{
public:
Sha1();
/*!
* Sha1 hashes a string input and returns the raw digest
*
* @param input The string to hash
* @return The Sha1 digest converted to host endianness
*/
static std::string sha1_digest(const std::string &input)
{
Sha1 ctx;
ctx.update(input);
ctx.final();
for(unsigned int &a : ctx.digest)
a = ntohl(a);
return std::string((char*)&ctx.digest[0], sizeof(ctx.digest));
}
private:
void update(const std::string &s);
void update(std::istream &is);
void final();
uint32_t digest[5];
std::string buffer;
uint64_t transforms;
};
#endif //FRNETLIB_SHA1_H

View File

@ -95,7 +95,7 @@ namespace fr
virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0; virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0;
/*! /*!
* Gets the socket descriptor. * Gets the underlying socket descriptor.
* *
* @return The socket descriptor. * @return The socket descriptor.
*/ */
@ -175,13 +175,15 @@ namespace fr
* *
* If a client attempts to send a packet larger than sz bytes, then * If a client attempts to send a packet larger than sz bytes, then
* the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded * the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded
* will be returned. Pass '0' to indicate no limit. The default value is 0. * will be returned. Pass '0' to indicate no limit.
* *
* This should be used to prevent potential abuse, as a client could say that * This should be used to prevent potential abuse, as a client could say that
* it's going to send a 200GiB packet, which would cause the Socket to try and * it's going to send a 200GiB packet, which would cause the Socket to try and
* allocate that much memory to accommodate the data, which is most likely not * allocate that much memory to accommodate the data, which is most likely not
* desirable. * desirable.
* *
* By default, there is no limit (0)
*
* @param sz The maximum number of bytes that may be received in an fr::Packet * @param sz The maximum number of bytes that may be received in an fr::Packet
*/ */
void set_max_receive_size(uint32_t sz); void set_max_receive_size(uint32_t sz);
@ -207,12 +209,12 @@ namespace fr
/*! /*!
* Gets the max packet size. See set_max_packet_size * Gets the max packet size. See set_max_packet_size
* for more information. * for more information. If this returns 0, then
* * there is no limit.
* *
* @return The max packet size * @return The max packet size
*/ */
inline uint32_t get_max_receive_size() inline uint32_t get_max_receive_size() const
{ {
return max_receive_size; return max_receive_size;
} }
@ -222,7 +224,7 @@ namespace fr
* *
* @return The string address * @return The string address
*/ */
inline const std::string &get_remote_address() inline const std::string &get_remote_address() const
{ {
return remote_address; return remote_address;
} }

View File

@ -82,9 +82,9 @@ public:
void set_descriptor(void *descriptor_data) override; void set_descriptor(void *descriptor_data) override;
/*! /*!
* Gets the unerlying socket descriptor * Gets the underlying socket descriptor
* *
* @return The socket descriptor * @return The socket descriptor. -1 typically indicates no connection.
*/ */
int32_t get_socket_descriptor() const override; int32_t get_socket_descriptor() const override;

127
include/frnetlib/WebFrame.h Normal file
View File

@ -0,0 +1,127 @@
//
// Created by fred on 01/03/18.
//
#ifndef FRNETLIB_WEBFRAME_H
#define FRNETLIB_WEBFRAME_H
#include "Sendable.h"
namespace fr
{
class WebFrame : public fr::Sendable
{
public:
enum Opcode : uint8_t
{
Continuation = 0,
Text = 1,
Binary = 2,
Disconnect = 8,
Ping = 9,
Pong = 10
};
/*!
* Constructs the WebFrame.
*
* @param type The opcode type. See set_opcode. Text by default.
*/
explicit WebFrame(Opcode type = Text);
/*!
* Get's the received payload data. (Data received).
*
* @return The payload
*/
inline const std::string get_payload()
{
return payload;
}
/*!
* Sets the frame payload (data being sent)
*
* @param payload_ The payload to send
*/
inline void set_payload(std::string payload_)
{
payload = std::move(payload_);
}
/*!
* Sets the WebFrame opcode (it's type)
* This should be Text for non-binary data.
* Or Binary for binary data.
* Or Continuation if it's the next part of a fragmented message (set_final() set to true if this is the last part)
*
* @param opcode_ The opcode to use.
*/
inline void set_opcode(Opcode opcode_)
{
opcode = opcode_;
}
/*!
* Gets the WebFrame opcode (it's type)
*
* @return The opcode of the frame
*/
inline Opcode get_opcode()
{
return opcode;
}
/*!
* Checks if the frame is the final part of the message
*
* @return True if this frame is the final part, false if there's more to come.
*/
inline bool is_final()
{
return final;
}
/*!
* Sets whether the frame is the final part of a message or not.
*
* @param is_final True if this frame is the final part of a message. False if it's a fragment with more to come.
*/
inline void set_final(bool is_final = true)
{
final = is_final;
}
protected:
/*!
* Overridable send, to allow
* custom types to be directly sent through
* sockets.
*
* @param socket The socket to send through
* @return Status indicating if the send succeeded or not.
*/
Socket::Status send(Socket *socket) override;
/*!
* Overrideable receive, to allow
* custom types to be directly received through
* sockets.
*
* @note If the maximum message length is exceeded, then the connection will be closed
* @param socket The socket to send through
* @return Status indicating if the send succeeded or not.
*/
Socket::Status receive(Socket *socket) override;
private:
std::string payload;
Opcode opcode;
bool final;
static uint32_t current_mask_key;
};
}
#endif //FRNETLIB_WEBFRAME_H

View File

@ -0,0 +1,187 @@
//
// Created by fred on 01/03/18.
//
#ifndef FRNETLIB_WEBSOCKET_H
#define FRNETLIB_WEBSOCKET_H
#include "frnetlib/Socket.h"
#include "frnetlib/HttpRequest.h"
#include "frnetlib/HttpResponse.h"
#include "Base64.h"
#include "Sha1.h"
#include "WebFrame.h"
namespace fr
{
class WebSocketBase
{
public:
/*!
* Checks if the socket is the client component or the server component
*
* @return True if it's the client component. False otherwise.
*/
virtual bool is_client()=0;
};
template<typename SocketType = fr::Socket>
class WebSocket : public SocketType, public WebSocketBase
{
public:
WebSocket()
: is_the_client(true)
{}
/*!
* Connects the WebSocket to a WebSocket server. Makes
* the connection using the underlying socket, and then handshakes.
*
* @param address The address of the socket to connect to
* @param port The port of the socket to connect to
* @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default.
* @return A Socket::Status indicating the status of the operation (Success on success, an error type on failure).
*/
Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override
{
//Establish a connection using the parent class
Socket::Status status = SocketType::connect(address, port, timeout);
if(status != fr::Socket::Success)
return status;
//Send an upgrade request header
std::string websocket_key = Base64::encode(std::to_string(std::time(nullptr)));
HttpRequest request;
request.header("sec-websocket-key") = websocket_key;
request.header("sec-websocket-version") = "13";
request.header("connection") = "upgrade";
request.header("upgrade") = "websocket";
status = SocketType::send(request);
if(status != fr::Socket::Success)
return status;
//Receive the response
HttpResponse response;
status = SocketType::receive(response);
if(status != fr::Socket::Success)
return status;
if(response.get_status() != Http::SwitchingProtocols)
{
disconnect();
return Socket::HandshakeFailed;
}
//Verify the sec-websocket-accept header
std::string derived_key = response.header("sec-websocket-accept");
if(derived_key != Base64::encode(Sha1::sha1_digest(websocket_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")))
{
disconnect();
return Socket::HandshakeFailed;
}
return fr::Socket::Success;
}
/*!
* Sends a Disconnect Frame and then
* closes the connection.
*/
void disconnect() override
{
WebFrame frame;
frame.set_opcode(WebFrame::Disconnect);
if(SocketType::connected())
SocketType::send(frame);
SocketType::close_socket();
}
/*!
* Sets the socket file descriptor. This is called by the Listener
* when accepting a connection, and so we can use the opportunity to
* handshake with the server.
*
* @param descriptor The socket descriptor.
*/
void set_descriptor(void *descriptor) override
{
SocketType::set_descriptor(descriptor);
if(!descriptor || SocketType::get_socket_descriptor() == -1)
return;
is_the_client = false; //If we're accepting a connection then we're the server
//Initialise connection, receive the handshake
HttpRequest request;
if(SocketType::receive(request) != fr::Socket::Success)
throw std::runtime_error("Failed to receive WebSock handshake");
if(request.header("Upgrade") != "websocket" || request.get_type() != fr::Http::Get)
throw std::runtime_error("Client isn't using the WebSock protocol");
//Calculate the derived key, then send back our response
std::string derived_key = Base64::encode(Sha1::sha1_digest(request.header("sec-websocket-key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
HttpResponse response;
response.set_status(fr::Http::SwitchingProtocols);
response.header("Upgrade") = "websocket";
response.header("Connection") = "Upgrade";
response.header("Sec-WebSocket-Accept") = derived_key;
SocketType::send(response);
}
/*!
* Receive a Sendable object through the socket.
* Internally sends back a pong if a ping is received.
*
* @param obj The object to receive
* @return The status of the receive
*/
Socket::Status receive(Sendable &obj) override
{
WebFrame &frame = dynamic_cast<WebFrame&>(obj);
//Try and receive a message. If it's a ping, then silently send back a pong.
Socket::Status status;
while(true)
{
status = SocketType::receive(obj);
if(status != Socket::Success)
return status;
if(frame.get_opcode() == WebFrame::Ping)
{
frame.set_opcode(WebFrame::Pong);
status = SocketType::send(frame);
if(status != fr::Socket::Success)
return status;
continue;
}
break;
}
//If it's a disconnect
if(frame.get_opcode() == WebFrame::Disconnect)
{
disconnect();
return Socket::Disconnected;
}
return status;
}
/*!
* Checks to see if the socket initialised the connection, or
* if it was accepted by a listener.
*
* @return True if accepted by a listener, false otherwise.
*/
inline bool is_client() override
{
return is_the_client;
}
private:
bool is_the_client;
};
}
#endif //FRNETLIB_WEBSOCKET_H

42
src/Base64.cpp Normal file
View File

@ -0,0 +1,42 @@
//
// Created by fred on 01/03/18.
//
#include "frnetlib/Base64.h"
std::string Base64::encode(const std::string &input)
{
static const std::string base64_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
std::string out;
out.reserve(input.size() + (input.size() / 3) + 1);
int a;
//Do as many sets of 3 bytes as we can
for(a = 0; a < input.size() - 2; a += 3)
{
out += base64_table[(input[a] >> 2) & 0x3F]; //Store the first 6 bits of the first byte
out += base64_table[((input[a] & 0x3) << 4) | (input[a + 1] & 0xF0) >> 4]; //Store the last 2 bits of the first byte combined with the first 4 bits of the second byte
out += base64_table[((input[a + 1] & 0xF) << 2) | (input[a + 2] & 0xC0) >> 6]; //Store the last 4 bits of the second byte combined with the first 2 bits of the third byte
out += base64_table[input[a + 2] & 0x3F]; //Store the last 6 bits of the third byte
}
//Check if there's a remainder
if(a < input.size())
{
out += base64_table[(input[a] >> 2) & 0x3F]; //Store first 6 bits of the first byte
if(a == input.size() - 2) //There's 2 bytes left
{
out += base64_table[((input[a] & 0x3) << 4) | (input[a + 1] & 0xF0) >> 4]; //Store last 2 bits of first byte, and first 4 bits of second byte
out += base64_table[(input[a + 1] & 0xF) << 2]; //Finally store the last 4 bits of the second byte
out += "=";
}
else //There's 1 byte left
{
out += base64_table[(input[a] & 0x3) << 4]; //Store last 2 bits of first byte
out += "==";
}
}
return out;
}

287
src/Sha1.cpp Normal file
View File

@ -0,0 +1,287 @@
/*
Sha1.cpp - source code of
============
SHA-1 in C++
============
100% Public Domain.
Original C Code
-- Steve Reid <steve@edmweb.com>
Small changes to fit into bglibs
-- Bruce Guenter <bruce@untroubled.org>
Translation to simpler C++ Code
-- Volker Grabsch <vog@notjusthosting.com>
Safety fixes
-- Eugene Hopkinson <slowriot at voxelstorm dot com>
--------------------------------------------------------------------
| Sourced from: https://github.com/vog/sha1 |
| Altered for inclusion within frnetlib. |
--------------------------------------------------------------------
*/
#include "frnetlib/Sha1.h"
#include <sstream>
#include <iomanip>
#include <fstream>
static const size_t BLOCK_INTS = 16; /* number of 32bit integers per Sha1 block */
static const size_t BLOCK_BYTES = BLOCK_INTS * 4;
static void reset(uint32_t digest[], std::string &buffer, uint64_t &transforms)
{
/* Sha1 initialization constants */
digest[0] = 0x67452301;
digest[1] = 0xefcdab89;
digest[2] = 0x98badcfe;
digest[3] = 0x10325476;
digest[4] = 0xc3d2e1f0;
/* Reset counters */
buffer = "";
transforms = 0;
}
static uint32_t rol(const uint32_t value, const size_t bits)
{
return (value << bits) | (value >> (32 - bits));
}
static uint32_t blk(const uint32_t block[BLOCK_INTS], const size_t i)
{
return rol(block[(i+13)&15] ^ block[(i+8)&15] ^ block[(i+2)&15] ^ block[i], 1);
}
/*
* (R0+R1), R2, R3, R4 are the different operations used in Sha1
*/
static void R0(const uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5);
w = rol(w, 30);
}
static void R1(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5);
w = rol(w, 30);
}
static void R2(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (w^x^y) + block[i] + 0x6ed9eba1 + rol(v, 5);
w = rol(w, 30);
}
static void R3(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (((w|x)&y)|(w&x)) + block[i] + 0x8f1bbcdc + rol(v, 5);
w = rol(w, 30);
}
static void R4(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i)
{
block[i] = blk(block, i);
z += (w^x^y) + block[i] + 0xca62c1d6 + rol(v, 5);
w = rol(w, 30);
}
/*
* Hash a single 512-bit block. This is the core of the algorithm.
*/
static void transform(uint32_t digest[], uint32_t block[BLOCK_INTS], uint64_t &transforms)
{
/* Copy digest[] to working vars */
uint32_t a = digest[0];
uint32_t b = digest[1];
uint32_t c = digest[2];
uint32_t d = digest[3];
uint32_t e = digest[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
R0(block, a, b, c, d, e, 0);
R0(block, e, a, b, c, d, 1);
R0(block, d, e, a, b, c, 2);
R0(block, c, d, e, a, b, 3);
R0(block, b, c, d, e, a, 4);
R0(block, a, b, c, d, e, 5);
R0(block, e, a, b, c, d, 6);
R0(block, d, e, a, b, c, 7);
R0(block, c, d, e, a, b, 8);
R0(block, b, c, d, e, a, 9);
R0(block, a, b, c, d, e, 10);
R0(block, e, a, b, c, d, 11);
R0(block, d, e, a, b, c, 12);
R0(block, c, d, e, a, b, 13);
R0(block, b, c, d, e, a, 14);
R0(block, a, b, c, d, e, 15);
R1(block, e, a, b, c, d, 0);
R1(block, d, e, a, b, c, 1);
R1(block, c, d, e, a, b, 2);
R1(block, b, c, d, e, a, 3);
R2(block, a, b, c, d, e, 4);
R2(block, e, a, b, c, d, 5);
R2(block, d, e, a, b, c, 6);
R2(block, c, d, e, a, b, 7);
R2(block, b, c, d, e, a, 8);
R2(block, a, b, c, d, e, 9);
R2(block, e, a, b, c, d, 10);
R2(block, d, e, a, b, c, 11);
R2(block, c, d, e, a, b, 12);
R2(block, b, c, d, e, a, 13);
R2(block, a, b, c, d, e, 14);
R2(block, e, a, b, c, d, 15);
R2(block, d, e, a, b, c, 0);
R2(block, c, d, e, a, b, 1);
R2(block, b, c, d, e, a, 2);
R2(block, a, b, c, d, e, 3);
R2(block, e, a, b, c, d, 4);
R2(block, d, e, a, b, c, 5);
R2(block, c, d, e, a, b, 6);
R2(block, b, c, d, e, a, 7);
R3(block, a, b, c, d, e, 8);
R3(block, e, a, b, c, d, 9);
R3(block, d, e, a, b, c, 10);
R3(block, c, d, e, a, b, 11);
R3(block, b, c, d, e, a, 12);
R3(block, a, b, c, d, e, 13);
R3(block, e, a, b, c, d, 14);
R3(block, d, e, a, b, c, 15);
R3(block, c, d, e, a, b, 0);
R3(block, b, c, d, e, a, 1);
R3(block, a, b, c, d, e, 2);
R3(block, e, a, b, c, d, 3);
R3(block, d, e, a, b, c, 4);
R3(block, c, d, e, a, b, 5);
R3(block, b, c, d, e, a, 6);
R3(block, a, b, c, d, e, 7);
R3(block, e, a, b, c, d, 8);
R3(block, d, e, a, b, c, 9);
R3(block, c, d, e, a, b, 10);
R3(block, b, c, d, e, a, 11);
R4(block, a, b, c, d, e, 12);
R4(block, e, a, b, c, d, 13);
R4(block, d, e, a, b, c, 14);
R4(block, c, d, e, a, b, 15);
R4(block, b, c, d, e, a, 0);
R4(block, a, b, c, d, e, 1);
R4(block, e, a, b, c, d, 2);
R4(block, d, e, a, b, c, 3);
R4(block, c, d, e, a, b, 4);
R4(block, b, c, d, e, a, 5);
R4(block, a, b, c, d, e, 6);
R4(block, e, a, b, c, d, 7);
R4(block, d, e, a, b, c, 8);
R4(block, c, d, e, a, b, 9);
R4(block, b, c, d, e, a, 10);
R4(block, a, b, c, d, e, 11);
R4(block, e, a, b, c, d, 12);
R4(block, d, e, a, b, c, 13);
R4(block, c, d, e, a, b, 14);
R4(block, b, c, d, e, a, 15);
/* Add the working vars back into digest[] */
digest[0] += a;
digest[1] += b;
digest[2] += c;
digest[3] += d;
digest[4] += e;
/* Count the number of transformations */
transforms++;
}
static void buffer_to_block(const std::string &buffer, uint32_t block[BLOCK_INTS])
{
/* Convert the std::string (byte buffer) to a uint32_t array (MSB) */
for (size_t i = 0; i < BLOCK_INTS; i++)
{
block[i] = (buffer[4*i+3] & 0xff)
| (buffer[4*i+2] & 0xff)<<8
| (buffer[4*i+1] & 0xff)<<16
| (buffer[4*i+0] & 0xff)<<24;
}
}
Sha1::Sha1()
{
reset(digest, buffer, transforms);
}
void Sha1::update(const std::string &s)
{
std::istringstream is(s);
update(is);
}
void Sha1::update(std::istream &is)
{
while (true)
{
char sbuf[BLOCK_BYTES];
is.read(sbuf, BLOCK_BYTES - buffer.size());
buffer.append(sbuf, is.gcount());
if (buffer.size() != BLOCK_BYTES)
{
return;
}
uint32_t block[BLOCK_INTS];
buffer_to_block(buffer, block);
transform(digest, block, transforms);
buffer.clear();
}
}
/*
* Add padding and finish up
*/
void Sha1::final()
{
/* Total number of hashed bits */
uint64_t total_bits = (transforms*BLOCK_BYTES + buffer.size()) * 8;
/* Padding */
buffer += 0x80;
size_t orig_size = buffer.size();
while (buffer.size() < BLOCK_BYTES)
{
buffer += (char)0x00;
}
uint32_t block[BLOCK_INTS];
buffer_to_block(buffer, block);
if (orig_size > BLOCK_BYTES - 8)
{
transform(digest, block, transforms);
for (size_t i = 0; i < BLOCK_INTS - 2; i++)
{
block[i] = 0;
}
}
/* Append total_bits, split this uint64_t into two uint32_t */
block[BLOCK_INTS - 1] = total_bits;
block[BLOCK_INTS - 2] = (total_bits >> 32);
transform(digest, block, transforms);
}

169
src/WebFrame.cpp Normal file
View File

@ -0,0 +1,169 @@
//
// Created by fred on 01/03/18.
//
#include "frnetlib/WebFrame.h"
#include "frnetlib/WebSocket.h"
namespace fr
{
uint32_t WebFrame::current_mask_key = static_cast<uint32_t>(std::time(nullptr));
WebFrame::WebFrame(WebFrame::Opcode type)
: opcode(type),
final(true)
{
}
fr::Socket::Status WebFrame::send(Socket *socket_)
{
auto *socket = dynamic_cast<WebSocketBase*>(socket_);
if(!socket)
return Socket::Error;
uint16_t first_2bytes = 0;
std::string buffer;
//Set fin bit. Bit 1.
first_2bytes |= final << 15;
//Set opcode bit
first_2bytes |= opcode << 8;
//Set mask bit (dependent on is_client flag, only client -> server messages are masked)
first_2bytes |= socket->is_client() << 7;
//Set payload length
if(payload.size() <= 125)
first_2bytes |= payload.size();
else
first_2bytes |= (payload.size() < std::numeric_limits<uint16_t>::max()) ? 126 : 127;
first_2bytes = htons(first_2bytes);
buffer.append((char*)&first_2bytes, sizeof(first_2bytes));
//Set additional payload bits if large enough
if(payload.size() > 125)
{
if(payload.size() < std::numeric_limits<uint16_t>::max()) //16bit length
{
auto len = htons(static_cast<uint16_t>(payload.size()));
buffer.append((char*)&len, sizeof(len));
}
else //64bit length
{
uint64_t len = htonll(payload.size());
buffer.append((char*)&len, sizeof(len));
}
}
//Add a masking key if we're the client
if(socket->is_client())
{
union
{
uint32_t mask_key;
char str_mask_key[4];
} mask_union{};
mask_union.mask_key = ++current_mask_key;
buffer.append((char*)&mask_union.mask_key, sizeof(mask_union.mask_key));
//Encode the payload using the mask key
for(size_t a = 0; a < payload.size(); ++a)
{
payload[a] = payload[a] ^ mask_union.str_mask_key[a % 4];
}
}
buffer.append(payload);
return socket_->send_raw(buffer.c_str(), buffer.size());
}
Socket::Status WebFrame::receive(Socket *socket)
{
auto *socket_ = dynamic_cast<WebSocketBase*>(socket);
if(!socket_)
return Socket::Error;
payload.clear();
Socket::Status status;
uint16_t first_2bytes;
status = socket->receive_all(&first_2bytes, sizeof(first_2bytes));
if(status != fr::Socket::Success)
return status;
first_2bytes = ntohs(first_2bytes);
//Extract fin bit. Read bit 1.
final = static_cast<bool>((first_2bytes >> 15) & 0x1);
//Extract opcode. Read bits 4-7
opcode = static_cast<Opcode>((first_2bytes >> 8) & 0xF);
//Extract mask, if we're the server then messages should always be masked. Read bit 9
auto mask = static_cast<bool>((first_2bytes >> 7) & 0x1);
if(mask == socket_->is_client())
{
socket->disconnect();
return fr::Socket::Disconnected;
}
//Extract payload length. Read bits 9-15
auto payload_length = static_cast<uint64_t>(first_2bytes & 0x7F);
if(payload_length == 126) //Length is longer than 7 bit, so read 16bit length
{
uint16_t length;
status = socket->receive_all(&length, sizeof(length));
payload_length = ntohs(length);
if(status != fr::Socket::Success)
return status;
}
else if(payload_length == 127) //Length is longer than 16 bit, so read 64bit length
{
status = socket->receive_all(&payload_length, sizeof(payload_length));
payload_length = ntohll(payload_length);
if(status != fr::Socket::Success)
return status;
}
//Verify that payload length isn't too large
if(socket->get_max_receive_size() && payload_length > socket->get_max_receive_size())
{
socket->disconnect(); //We're forced to disconnect, otherwise we'll be out of sync with the server
return Socket::MaxPacketSizeExceeded;
}
//Read masking key if the mask bit is set
union
{
uint32_t mask_key;
char str_mask_key[4];
} mask_union{};
if(mask)
{
status = socket->receive_all(&mask_union.mask_key, 4);
if(status != fr::Socket::Success)
return status;
}
//Read payload
payload.resize(payload_length, '\0');
status = socket->receive_all(&payload[0], payload_length);
if(status != fr::Socket::Success)
return status;
//Decode the payload if the mask bit is set
if(mask)
{
for(size_t a = 0; a < payload_length; ++a)
{
payload[a] = payload[a] ^ mask_union.str_mask_key[a % 4];
}
}
return fr::Socket::Success;
}
}