Added fr::Socket::set_max_packet_size

Can be used to limit the size of packets being received, to prevent malicious attacks.
This commit is contained in:
Unknown 2017-06-03 14:39:49 +01:00
parent 71874837a0
commit eec983c8b7
2 changed files with 33 additions and 2 deletions

View File

@ -26,6 +26,7 @@ namespace fr
ConnectionFailed = 7, ConnectionFailed = 7,
HandshakeFailed = 8, HandshakeFailed = 8,
VerificationFailed = 9, VerificationFailed = 9,
MaxPacketSizeExceeded = 10,
}; };
enum IP enum IP
@ -158,6 +159,22 @@ namespace fr
* @param version Should IPv4, IPv6 be used, or any? * @param version Should IPv4, IPv6 be used, or any?
*/ */
void set_inet_version(IP version); void set_inet_version(IP version);
/*!
* Sets the maximum fr::Packet size that may be received by the socket.
*
* If a client attempts to send a packet larger than sz bytes, then
* the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded
* will be returned. Pass '0' to indicate no limit. The default value is 0.
*
* This should be used to prevent potential abuse, as a client could say that
* it's going to send a 200GiB packet, which would cause the Socket to try and
* allocate that much memory to accommodate the data, which is most likely not
* desirable.
*
* @param sz The maximum number of bytes that may be received in an fr::Packet
*/
void set_max_packet_size(uint32_t sz);
protected: protected:
/*! /*!
@ -171,6 +188,7 @@ namespace fr
std::mutex outbound_mutex; std::mutex outbound_mutex;
std::mutex inbound_mutex; std::mutex inbound_mutex;
int ai_family; int ai_family;
uint32_t max_packet_size;
#ifdef _WIN32 #ifdef _WIN32
static WSADATA wsaData; static WSADATA wsaData;
#endif // _WIN32 #endif // _WIN32

View File

@ -16,7 +16,8 @@ namespace fr
Socket::Socket() noexcept Socket::Socket() noexcept
: is_blocking(true), : is_blocking(true),
ai_family(AF_UNSPEC) ai_family(AF_UNSPEC),
max_packet_size(0)
{ {
if(instance_count == 0) if(instance_count == 0)
{ {
@ -68,6 +69,13 @@ namespace fr
return status; return status;
packet_length = ntohl(packet_length); packet_length = ntohl(packet_length);
//Check that packet_length doesn't exceed the limit, if any
if(max_packet_size && max_packet_size > max_packet_size)
{
close_socket();
return fr::Socket::MaxPacketSizeExceeded;
}
//Now we've got the length, read the rest of the data in //Now we've got the length, read the rest of the data in
packet.buffer.resize(packet_length + PACKET_HEADER_LENGTH); packet.buffer.resize(packet_length + PACKET_HEADER_LENGTH);
status = receive_all(&packet.buffer[PACKET_HEADER_LENGTH], packet_length); status = receive_all(&packet.buffer[PACKET_HEADER_LENGTH], packet_length);
@ -83,7 +91,7 @@ namespace fr
if(!connected()) if(!connected())
return Socket::Disconnected; return Socket::Disconnected;
int32_t bytes_remaining = buffer_size; int32_t bytes_remaining = (int32_t) buffer_size;
size_t bytes_read = 0; size_t bytes_read = 0;
while(bytes_remaining > 0) while(bytes_remaining > 0)
{ {
@ -128,4 +136,9 @@ namespace fr
throw std::logic_error("Unknown Socket::IP value passed to set_inet_version()"); throw std::logic_error("Unknown Socket::IP value passed to set_inet_version()");
} }
} }
void Socket::set_max_packet_size(uint32_t sz)
{
max_packet_size = sz;
}
} }