diff --git a/CMakeLists.txt b/CMakeLists.txt index 1703890..755f4ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,16 +6,15 @@ set(FRNETLIB_LINK_LIBRARIES "") set( CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules) #User options -option(USE_SSL "Use SSL" OFF) +option(USE_SSL "Enable SSL support" OFF) +option(BUILD_EXAMPLES "Build frnetlib examples" ON) +option(BUILD_TESTS "Build frnetlib tests" ON) +option(BUILD_WEBSOCK "Enable WebSocket support" ON) set(FRNETLIB_BUILD_SHARED_LIBS false CACHE BOOL "Build shared library.") set(MAX_HTTP_HEADER_SIZE "0xC800" CACHE STRING "The maximum allowed HTTP header size in bytes") set(MAX_HTTP_BODY_SIZE "0xA00000" CACHE STRING "The maximum allowed HTTP body size in bytes") set(LISTEN_QUEUE_SIZE "64" CACHE STRING "The listen queue depth for fr::TcpListener/fr::SSLListener") -#Enable tests and examples by default -option(BUILD_EXAMPLES "Build frnetlib examples" ON) -option(BUILD_TESTS "Build frnetlib tests" ON) - #Configure defines based on user options add_definitions(-DMAX_HTTP_HEADER_SIZE=${MAX_HTTP_HEADER_SIZE}) add_definitions(-DMAX_HTTP_BODY_SIZE=${MAX_HTTP_BODY_SIZE}) @@ -27,6 +26,10 @@ if(USE_SSL) set(SOURCE_FILES ${SOURCE_FILES} src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h) endif() +if(BUILD_WEBSOCK) + set(SOURCE_FILES ${SOURCE_FILES} src/WebFrame.cpp include/frnetlib/WebFrame.h src/Sha1.cpp include/frnetlib/Sha1.h src/Base64.cpp include/frnetlib/Base64.h src/Sha1.cpp include/frnetlib/WebSocket.h) +endif() + add_definitions(-DNOMINMAX) add_definitions(-Dhtonf) add_definitions(-Dhtonll) diff --git a/include/frnetlib/Base64.h b/include/frnetlib/Base64.h new file mode 100644 index 0000000..20e79d5 --- /dev/null +++ b/include/frnetlib/Base64.h @@ -0,0 +1,25 @@ +// +// Created by fred on 01/03/18. +// + +#ifndef FRNETLIB_BASE64_H +#define FRNETLIB_BASE64_H +#include + +class Base64 +{ +public: + /*! + * Encodes a string into Base64 + * + * @param input The string to encode + * @return The resulting encoded string + */ + static std::string encode(const std::string &input); + + //There's no decode function at the moment. Maybe I'll write one eventually. Sorry. +private: + +}; + +#endif //FRNETLIB_BASE64_H diff --git a/include/frnetlib/SSLSocket.h b/include/frnetlib/SSLSocket.h index ec36229..3ac02a4 100644 --- a/include/frnetlib/SSLSocket.h +++ b/include/frnetlib/SSLSocket.h @@ -94,7 +94,7 @@ namespace fr /*! * Gets the underlying socket descriptor. * - * @return The socket's descriptor. + * @return The socket's descriptor. -1 indicates no connection. */ inline int32_t get_socket_descriptor() const override { diff --git a/include/frnetlib/Sha1.h b/include/frnetlib/Sha1.h new file mode 100644 index 0000000..f9c2d7e --- /dev/null +++ b/include/frnetlib/Sha1.h @@ -0,0 +1,43 @@ +// +// Created by fred on 01/03/18. +// + +#ifndef FRNETLIB_SHA1_H +#define FRNETLIB_SHA1_H + +#include +#include "frnetlib/NetworkEncoding.h" + +class Sha1 +{ +public: + Sha1(); + + /*! + * Sha1 hashes a string input and returns the raw digest + * + * @param input The string to hash + * @return The Sha1 digest converted to host endianness + */ + static std::string sha1_digest(const std::string &input) + { + Sha1 ctx; + ctx.update(input); + ctx.final(); + for(unsigned int &a : ctx.digest) + a = ntohl(a); + + return std::string((char*)&ctx.digest[0], sizeof(ctx.digest)); + } + +private: + void update(const std::string &s); + void update(std::istream &is); + void final(); + + uint32_t digest[5]; + std::string buffer; + uint64_t transforms; +}; + +#endif //FRNETLIB_SHA1_H diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index f68a113..f68921f 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -95,7 +95,7 @@ namespace fr virtual Status receive_raw(void *data, size_t data_size, size_t &received) = 0; /*! - * Gets the socket descriptor. + * Gets the underlying socket descriptor. * * @return The socket descriptor. */ @@ -175,13 +175,15 @@ namespace fr * * If a client attempts to send a packet larger than sz bytes, then * the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded - * will be returned. Pass '0' to indicate no limit. The default value is 0. + * will be returned. Pass '0' to indicate no limit. * * This should be used to prevent potential abuse, as a client could say that * it's going to send a 200GiB packet, which would cause the Socket to try and * allocate that much memory to accommodate the data, which is most likely not * desirable. * + * By default, there is no limit (0) + * * @param sz The maximum number of bytes that may be received in an fr::Packet */ void set_max_receive_size(uint32_t sz); @@ -207,12 +209,12 @@ namespace fr /*! * Gets the max packet size. See set_max_packet_size - * for more information. - * + * for more information. If this returns 0, then + * there is no limit. * * @return The max packet size */ - inline uint32_t get_max_receive_size() + inline uint32_t get_max_receive_size() const { return max_receive_size; } @@ -222,7 +224,7 @@ namespace fr * * @return The string address */ - inline const std::string &get_remote_address() + inline const std::string &get_remote_address() const { return remote_address; } diff --git a/include/frnetlib/TcpSocket.h b/include/frnetlib/TcpSocket.h index d727a2f..d9f42ef 100644 --- a/include/frnetlib/TcpSocket.h +++ b/include/frnetlib/TcpSocket.h @@ -82,9 +82,9 @@ public: void set_descriptor(void *descriptor_data) override; /*! - * Gets the unerlying socket descriptor + * Gets the underlying socket descriptor * - * @return The socket descriptor + * @return The socket descriptor. -1 typically indicates no connection. */ int32_t get_socket_descriptor() const override; diff --git a/include/frnetlib/WebFrame.h b/include/frnetlib/WebFrame.h new file mode 100644 index 0000000..103d4e1 --- /dev/null +++ b/include/frnetlib/WebFrame.h @@ -0,0 +1,127 @@ +// +// Created by fred on 01/03/18. +// + +#ifndef FRNETLIB_WEBFRAME_H +#define FRNETLIB_WEBFRAME_H + + +#include "Sendable.h" + +namespace fr +{ + class WebFrame : public fr::Sendable + { + public: + enum Opcode : uint8_t + { + Continuation = 0, + Text = 1, + Binary = 2, + Disconnect = 8, + Ping = 9, + Pong = 10 + }; + + /*! + * Constructs the WebFrame. + * + * @param type The opcode type. See set_opcode. Text by default. + */ + explicit WebFrame(Opcode type = Text); + + /*! + * Get's the received payload data. (Data received). + * + * @return The payload + */ + inline const std::string get_payload() + { + return payload; + } + + /*! + * Sets the frame payload (data being sent) + * + * @param payload_ The payload to send + */ + inline void set_payload(std::string payload_) + { + payload = std::move(payload_); + } + + /*! + * Sets the WebFrame opcode (it's type) + * This should be Text for non-binary data. + * Or Binary for binary data. + * Or Continuation if it's the next part of a fragmented message (set_final() set to true if this is the last part) + * + * @param opcode_ The opcode to use. + */ + inline void set_opcode(Opcode opcode_) + { + opcode = opcode_; + } + + /*! + * Gets the WebFrame opcode (it's type) + * + * @return The opcode of the frame + */ + inline Opcode get_opcode() + { + return opcode; + } + + /*! + * Checks if the frame is the final part of the message + * + * @return True if this frame is the final part, false if there's more to come. + */ + inline bool is_final() + { + return final; + } + + /*! + * Sets whether the frame is the final part of a message or not. + * + * @param is_final True if this frame is the final part of a message. False if it's a fragment with more to come. + */ + inline void set_final(bool is_final = true) + { + final = is_final; + } + + protected: + /*! + * Overridable send, to allow + * custom types to be directly sent through + * sockets. + * + * @param socket The socket to send through + * @return Status indicating if the send succeeded or not. + */ + Socket::Status send(Socket *socket) override; + + /*! + * Overrideable receive, to allow + * custom types to be directly received through + * sockets. + * + * @note If the maximum message length is exceeded, then the connection will be closed + * @param socket The socket to send through + * @return Status indicating if the send succeeded or not. + */ + Socket::Status receive(Socket *socket) override; + + private: + std::string payload; + Opcode opcode; + bool final; + static uint32_t current_mask_key; + }; +} + + +#endif //FRNETLIB_WEBFRAME_H diff --git a/include/frnetlib/WebSocket.h b/include/frnetlib/WebSocket.h new file mode 100644 index 0000000..6ea8cef --- /dev/null +++ b/include/frnetlib/WebSocket.h @@ -0,0 +1,187 @@ +// +// Created by fred on 01/03/18. +// + +#ifndef FRNETLIB_WEBSOCKET_H +#define FRNETLIB_WEBSOCKET_H + +#include "frnetlib/Socket.h" +#include "frnetlib/HttpRequest.h" +#include "frnetlib/HttpResponse.h" +#include "Base64.h" +#include "Sha1.h" +#include "WebFrame.h" + +namespace fr +{ + class WebSocketBase + { + public: + /*! + * Checks if the socket is the client component or the server component + * + * @return True if it's the client component. False otherwise. + */ + virtual bool is_client()=0; + }; + + template + class WebSocket : public SocketType, public WebSocketBase + { + public: + WebSocket() + : is_the_client(true) + {} + + /*! + * Connects the WebSocket to a WebSocket server. Makes + * the connection using the underlying socket, and then handshakes. + * + * @param address The address of the socket to connect to + * @param port The port of the socket to connect to + * @param timeout The number of seconds to wait before timing the connection attempt out. Pass {} for default. + * @return A Socket::Status indicating the status of the operation (Success on success, an error type on failure). + */ + Socket::Status connect(const std::string &address, const std::string &port, std::chrono::seconds timeout) override + { + //Establish a connection using the parent class + Socket::Status status = SocketType::connect(address, port, timeout); + if(status != fr::Socket::Success) + return status; + + //Send an upgrade request header + std::string websocket_key = Base64::encode(std::to_string(std::time(nullptr))); + HttpRequest request; + request.header("sec-websocket-key") = websocket_key; + request.header("sec-websocket-version") = "13"; + request.header("connection") = "upgrade"; + request.header("upgrade") = "websocket"; + status = SocketType::send(request); + if(status != fr::Socket::Success) + return status; + + //Receive the response + HttpResponse response; + status = SocketType::receive(response); + if(status != fr::Socket::Success) + return status; + if(response.get_status() != Http::SwitchingProtocols) + { + disconnect(); + return Socket::HandshakeFailed; + } + + //Verify the sec-websocket-accept header + std::string derived_key = response.header("sec-websocket-accept"); + if(derived_key != Base64::encode(Sha1::sha1_digest(websocket_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))) + { + disconnect(); + return Socket::HandshakeFailed; + } + + return fr::Socket::Success; + } + + /*! + * Sends a Disconnect Frame and then + * closes the connection. + */ + void disconnect() override + { + WebFrame frame; + frame.set_opcode(WebFrame::Disconnect); + if(SocketType::connected()) + SocketType::send(frame); + SocketType::close_socket(); + } + + /*! + * Sets the socket file descriptor. This is called by the Listener + * when accepting a connection, and so we can use the opportunity to + * handshake with the server. + * + * @param descriptor The socket descriptor. + */ + void set_descriptor(void *descriptor) override + { + SocketType::set_descriptor(descriptor); + if(!descriptor || SocketType::get_socket_descriptor() == -1) + return; + + is_the_client = false; //If we're accepting a connection then we're the server + + //Initialise connection, receive the handshake + HttpRequest request; + if(SocketType::receive(request) != fr::Socket::Success) + throw std::runtime_error("Failed to receive WebSock handshake"); + + if(request.header("Upgrade") != "websocket" || request.get_type() != fr::Http::Get) + throw std::runtime_error("Client isn't using the WebSock protocol"); + + //Calculate the derived key, then send back our response + std::string derived_key = Base64::encode(Sha1::sha1_digest(request.header("sec-websocket-key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); + HttpResponse response; + response.set_status(fr::Http::SwitchingProtocols); + response.header("Upgrade") = "websocket"; + response.header("Connection") = "Upgrade"; + response.header("Sec-WebSocket-Accept") = derived_key; + SocketType::send(response); + } + + /*! + * Receive a Sendable object through the socket. + * Internally sends back a pong if a ping is received. + * + * @param obj The object to receive + * @return The status of the receive + */ + Socket::Status receive(Sendable &obj) override + { + WebFrame &frame = dynamic_cast(obj); + + //Try and receive a message. If it's a ping, then silently send back a pong. + Socket::Status status; + while(true) + { + status = SocketType::receive(obj); + if(status != Socket::Success) + return status; + + if(frame.get_opcode() == WebFrame::Ping) + { + frame.set_opcode(WebFrame::Pong); + status = SocketType::send(frame); + if(status != fr::Socket::Success) + return status; + continue; + } + break; + } + + //If it's a disconnect + if(frame.get_opcode() == WebFrame::Disconnect) + { + disconnect(); + return Socket::Disconnected; + } + return status; + } + + /*! + * Checks to see if the socket initialised the connection, or + * if it was accepted by a listener. + * + * @return True if accepted by a listener, false otherwise. + */ + inline bool is_client() override + { + return is_the_client; + } + + private: + bool is_the_client; + + }; +} + +#endif //FRNETLIB_WEBSOCKET_H diff --git a/src/Base64.cpp b/src/Base64.cpp new file mode 100644 index 0000000..f881e01 --- /dev/null +++ b/src/Base64.cpp @@ -0,0 +1,42 @@ +// +// Created by fred on 01/03/18. +// + +#include "frnetlib/Base64.h" + +std::string Base64::encode(const std::string &input) +{ + static const std::string base64_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(input.size() + (input.size() / 3) + 1); + int a; + + //Do as many sets of 3 bytes as we can + for(a = 0; a < input.size() - 2; a += 3) + { + out += base64_table[(input[a] >> 2) & 0x3F]; //Store the first 6 bits of the first byte + out += base64_table[((input[a] & 0x3) << 4) | (input[a + 1] & 0xF0) >> 4]; //Store the last 2 bits of the first byte combined with the first 4 bits of the second byte + out += base64_table[((input[a + 1] & 0xF) << 2) | (input[a + 2] & 0xC0) >> 6]; //Store the last 4 bits of the second byte combined with the first 2 bits of the third byte + out += base64_table[input[a + 2] & 0x3F]; //Store the last 6 bits of the third byte + } + + //Check if there's a remainder + if(a < input.size()) + { + out += base64_table[(input[a] >> 2) & 0x3F]; //Store first 6 bits of the first byte + if(a == input.size() - 2) //There's 2 bytes left + { + out += base64_table[((input[a] & 0x3) << 4) | (input[a + 1] & 0xF0) >> 4]; //Store last 2 bits of first byte, and first 4 bits of second byte + out += base64_table[(input[a + 1] & 0xF) << 2]; //Finally store the last 4 bits of the second byte + out += "="; + } + else //There's 1 byte left + { + out += base64_table[(input[a] & 0x3) << 4]; //Store last 2 bits of first byte + out += "=="; + } + } + + return out; +} diff --git a/src/Sha1.cpp b/src/Sha1.cpp new file mode 100644 index 0000000..75ee569 --- /dev/null +++ b/src/Sha1.cpp @@ -0,0 +1,287 @@ +/* + Sha1.cpp - source code of + ============ + SHA-1 in C++ + ============ + 100% Public Domain. + Original C Code + -- Steve Reid + Small changes to fit into bglibs + -- Bruce Guenter + Translation to simpler C++ Code + -- Volker Grabsch + Safety fixes + -- Eugene Hopkinson + +-------------------------------------------------------------------- +| Sourced from: https://github.com/vog/sha1 | +| Altered for inclusion within frnetlib. | +-------------------------------------------------------------------- +*/ + +#include "frnetlib/Sha1.h" +#include +#include +#include + + +static const size_t BLOCK_INTS = 16; /* number of 32bit integers per Sha1 block */ +static const size_t BLOCK_BYTES = BLOCK_INTS * 4; + + +static void reset(uint32_t digest[], std::string &buffer, uint64_t &transforms) +{ + /* Sha1 initialization constants */ + digest[0] = 0x67452301; + digest[1] = 0xefcdab89; + digest[2] = 0x98badcfe; + digest[3] = 0x10325476; + digest[4] = 0xc3d2e1f0; + + /* Reset counters */ + buffer = ""; + transforms = 0; +} + + +static uint32_t rol(const uint32_t value, const size_t bits) +{ + return (value << bits) | (value >> (32 - bits)); +} + + +static uint32_t blk(const uint32_t block[BLOCK_INTS], const size_t i) +{ + return rol(block[(i+13)&15] ^ block[(i+8)&15] ^ block[(i+2)&15] ^ block[i], 1); +} + + +/* + * (R0+R1), R2, R3, R4 are the different operations used in Sha1 + */ + +static void R0(const uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i) +{ + z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5); + w = rol(w, 30); +} + + +static void R1(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i) +{ + block[i] = blk(block, i); + z += ((w&(x^y))^y) + block[i] + 0x5a827999 + rol(v, 5); + w = rol(w, 30); +} + + +static void R2(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i) +{ + block[i] = blk(block, i); + z += (w^x^y) + block[i] + 0x6ed9eba1 + rol(v, 5); + w = rol(w, 30); +} + + +static void R3(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i) +{ + block[i] = blk(block, i); + z += (((w|x)&y)|(w&x)) + block[i] + 0x8f1bbcdc + rol(v, 5); + w = rol(w, 30); +} + + +static void R4(uint32_t block[BLOCK_INTS], const uint32_t v, uint32_t &w, const uint32_t x, const uint32_t y, uint32_t &z, const size_t i) +{ + block[i] = blk(block, i); + z += (w^x^y) + block[i] + 0xca62c1d6 + rol(v, 5); + w = rol(w, 30); +} + + +/* + * Hash a single 512-bit block. This is the core of the algorithm. + */ + +static void transform(uint32_t digest[], uint32_t block[BLOCK_INTS], uint64_t &transforms) +{ + /* Copy digest[] to working vars */ + uint32_t a = digest[0]; + uint32_t b = digest[1]; + uint32_t c = digest[2]; + uint32_t d = digest[3]; + uint32_t e = digest[4]; + + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(block, a, b, c, d, e, 0); + R0(block, e, a, b, c, d, 1); + R0(block, d, e, a, b, c, 2); + R0(block, c, d, e, a, b, 3); + R0(block, b, c, d, e, a, 4); + R0(block, a, b, c, d, e, 5); + R0(block, e, a, b, c, d, 6); + R0(block, d, e, a, b, c, 7); + R0(block, c, d, e, a, b, 8); + R0(block, b, c, d, e, a, 9); + R0(block, a, b, c, d, e, 10); + R0(block, e, a, b, c, d, 11); + R0(block, d, e, a, b, c, 12); + R0(block, c, d, e, a, b, 13); + R0(block, b, c, d, e, a, 14); + R0(block, a, b, c, d, e, 15); + R1(block, e, a, b, c, d, 0); + R1(block, d, e, a, b, c, 1); + R1(block, c, d, e, a, b, 2); + R1(block, b, c, d, e, a, 3); + R2(block, a, b, c, d, e, 4); + R2(block, e, a, b, c, d, 5); + R2(block, d, e, a, b, c, 6); + R2(block, c, d, e, a, b, 7); + R2(block, b, c, d, e, a, 8); + R2(block, a, b, c, d, e, 9); + R2(block, e, a, b, c, d, 10); + R2(block, d, e, a, b, c, 11); + R2(block, c, d, e, a, b, 12); + R2(block, b, c, d, e, a, 13); + R2(block, a, b, c, d, e, 14); + R2(block, e, a, b, c, d, 15); + R2(block, d, e, a, b, c, 0); + R2(block, c, d, e, a, b, 1); + R2(block, b, c, d, e, a, 2); + R2(block, a, b, c, d, e, 3); + R2(block, e, a, b, c, d, 4); + R2(block, d, e, a, b, c, 5); + R2(block, c, d, e, a, b, 6); + R2(block, b, c, d, e, a, 7); + R3(block, a, b, c, d, e, 8); + R3(block, e, a, b, c, d, 9); + R3(block, d, e, a, b, c, 10); + R3(block, c, d, e, a, b, 11); + R3(block, b, c, d, e, a, 12); + R3(block, a, b, c, d, e, 13); + R3(block, e, a, b, c, d, 14); + R3(block, d, e, a, b, c, 15); + R3(block, c, d, e, a, b, 0); + R3(block, b, c, d, e, a, 1); + R3(block, a, b, c, d, e, 2); + R3(block, e, a, b, c, d, 3); + R3(block, d, e, a, b, c, 4); + R3(block, c, d, e, a, b, 5); + R3(block, b, c, d, e, a, 6); + R3(block, a, b, c, d, e, 7); + R3(block, e, a, b, c, d, 8); + R3(block, d, e, a, b, c, 9); + R3(block, c, d, e, a, b, 10); + R3(block, b, c, d, e, a, 11); + R4(block, a, b, c, d, e, 12); + R4(block, e, a, b, c, d, 13); + R4(block, d, e, a, b, c, 14); + R4(block, c, d, e, a, b, 15); + R4(block, b, c, d, e, a, 0); + R4(block, a, b, c, d, e, 1); + R4(block, e, a, b, c, d, 2); + R4(block, d, e, a, b, c, 3); + R4(block, c, d, e, a, b, 4); + R4(block, b, c, d, e, a, 5); + R4(block, a, b, c, d, e, 6); + R4(block, e, a, b, c, d, 7); + R4(block, d, e, a, b, c, 8); + R4(block, c, d, e, a, b, 9); + R4(block, b, c, d, e, a, 10); + R4(block, a, b, c, d, e, 11); + R4(block, e, a, b, c, d, 12); + R4(block, d, e, a, b, c, 13); + R4(block, c, d, e, a, b, 14); + R4(block, b, c, d, e, a, 15); + + /* Add the working vars back into digest[] */ + digest[0] += a; + digest[1] += b; + digest[2] += c; + digest[3] += d; + digest[4] += e; + + /* Count the number of transformations */ + transforms++; +} + + +static void buffer_to_block(const std::string &buffer, uint32_t block[BLOCK_INTS]) +{ + /* Convert the std::string (byte buffer) to a uint32_t array (MSB) */ + for (size_t i = 0; i < BLOCK_INTS; i++) + { + block[i] = (buffer[4*i+3] & 0xff) + | (buffer[4*i+2] & 0xff)<<8 + | (buffer[4*i+1] & 0xff)<<16 + | (buffer[4*i+0] & 0xff)<<24; + } +} + + +Sha1::Sha1() +{ + reset(digest, buffer, transforms); +} + + +void Sha1::update(const std::string &s) +{ + std::istringstream is(s); + update(is); +} + + +void Sha1::update(std::istream &is) +{ + while (true) + { + char sbuf[BLOCK_BYTES]; + is.read(sbuf, BLOCK_BYTES - buffer.size()); + buffer.append(sbuf, is.gcount()); + if (buffer.size() != BLOCK_BYTES) + { + return; + } + uint32_t block[BLOCK_INTS]; + buffer_to_block(buffer, block); + transform(digest, block, transforms); + buffer.clear(); + } +} + + +/* + * Add padding and finish up + */ + +void Sha1::final() +{ + /* Total number of hashed bits */ + uint64_t total_bits = (transforms*BLOCK_BYTES + buffer.size()) * 8; + + /* Padding */ + buffer += 0x80; + size_t orig_size = buffer.size(); + while (buffer.size() < BLOCK_BYTES) + { + buffer += (char)0x00; + } + + uint32_t block[BLOCK_INTS]; + buffer_to_block(buffer, block); + + if (orig_size > BLOCK_BYTES - 8) + { + transform(digest, block, transforms); + for (size_t i = 0; i < BLOCK_INTS - 2; i++) + { + block[i] = 0; + } + } + + /* Append total_bits, split this uint64_t into two uint32_t */ + block[BLOCK_INTS - 1] = total_bits; + block[BLOCK_INTS - 2] = (total_bits >> 32); + transform(digest, block, transforms); +} diff --git a/src/WebFrame.cpp b/src/WebFrame.cpp new file mode 100644 index 0000000..ea44aa3 --- /dev/null +++ b/src/WebFrame.cpp @@ -0,0 +1,169 @@ +// +// Created by fred on 01/03/18. +// + +#include "frnetlib/WebFrame.h" +#include "frnetlib/WebSocket.h" + +namespace fr +{ + uint32_t WebFrame::current_mask_key = static_cast(std::time(nullptr)); + + WebFrame::WebFrame(WebFrame::Opcode type) + : opcode(type), + final(true) + { + + } + + fr::Socket::Status WebFrame::send(Socket *socket_) + { + auto *socket = dynamic_cast(socket_); + if(!socket) + return Socket::Error; + + uint16_t first_2bytes = 0; + std::string buffer; + + //Set fin bit. Bit 1. + first_2bytes |= final << 15; + + //Set opcode bit + first_2bytes |= opcode << 8; + + //Set mask bit (dependent on is_client flag, only client -> server messages are masked) + first_2bytes |= socket->is_client() << 7; + + //Set payload length + if(payload.size() <= 125) + first_2bytes |= payload.size(); + else + first_2bytes |= (payload.size() < std::numeric_limits::max()) ? 126 : 127; + first_2bytes = htons(first_2bytes); + buffer.append((char*)&first_2bytes, sizeof(first_2bytes)); + + //Set additional payload bits if large enough + if(payload.size() > 125) + { + if(payload.size() < std::numeric_limits::max()) //16bit length + { + auto len = htons(static_cast(payload.size())); + buffer.append((char*)&len, sizeof(len)); + } + else //64bit length + { + uint64_t len = htonll(payload.size()); + buffer.append((char*)&len, sizeof(len)); + } + } + + //Add a masking key if we're the client + if(socket->is_client()) + { + union + { + uint32_t mask_key; + char str_mask_key[4]; + } mask_union{}; + + mask_union.mask_key = ++current_mask_key; + buffer.append((char*)&mask_union.mask_key, sizeof(mask_union.mask_key)); + + //Encode the payload using the mask key + for(size_t a = 0; a < payload.size(); ++a) + { + payload[a] = payload[a] ^ mask_union.str_mask_key[a % 4]; + } + } + + buffer.append(payload); + return socket_->send_raw(buffer.c_str(), buffer.size()); + } + + Socket::Status WebFrame::receive(Socket *socket) + { + auto *socket_ = dynamic_cast(socket); + if(!socket_) + return Socket::Error; + payload.clear(); + Socket::Status status; + + uint16_t first_2bytes; + status = socket->receive_all(&first_2bytes, sizeof(first_2bytes)); + if(status != fr::Socket::Success) + return status; + first_2bytes = ntohs(first_2bytes); + + //Extract fin bit. Read bit 1. + final = static_cast((first_2bytes >> 15) & 0x1); + + //Extract opcode. Read bits 4-7 + opcode = static_cast((first_2bytes >> 8) & 0xF); + + //Extract mask, if we're the server then messages should always be masked. Read bit 9 + auto mask = static_cast((first_2bytes >> 7) & 0x1); + if(mask == socket_->is_client()) + { + socket->disconnect(); + return fr::Socket::Disconnected; + } + + + //Extract payload length. Read bits 9-15 + auto payload_length = static_cast(first_2bytes & 0x7F); + if(payload_length == 126) //Length is longer than 7 bit, so read 16bit length + { + uint16_t length; + status = socket->receive_all(&length, sizeof(length)); + payload_length = ntohs(length); + if(status != fr::Socket::Success) + return status; + } + else if(payload_length == 127) //Length is longer than 16 bit, so read 64bit length + { + status = socket->receive_all(&payload_length, sizeof(payload_length)); + payload_length = ntohll(payload_length); + if(status != fr::Socket::Success) + return status; + } + + //Verify that payload length isn't too large + if(socket->get_max_receive_size() && payload_length > socket->get_max_receive_size()) + { + socket->disconnect(); //We're forced to disconnect, otherwise we'll be out of sync with the server + return Socket::MaxPacketSizeExceeded; + } + + //Read masking key if the mask bit is set + union + { + uint32_t mask_key; + char str_mask_key[4]; + } mask_union{}; + if(mask) + { + status = socket->receive_all(&mask_union.mask_key, 4); + if(status != fr::Socket::Success) + return status; + } + + //Read payload + payload.resize(payload_length, '\0'); + status = socket->receive_all(&payload[0], payload_length); + if(status != fr::Socket::Success) + return status; + + //Decode the payload if the mask bit is set + if(mask) + { + for(size_t a = 0; a < payload_length; ++a) + { + payload[a] = payload[a] ^ mask_union.str_mask_key[a % 4]; + } + + } + return fr::Socket::Success; + } + +} +