Added fr::Socket::set_max_packet_size

Can be used to limit the size of packets being received, to prevent malicious attacks.
This commit is contained in:
Unknown 2017-06-03 14:39:49 +01:00
parent 71874837a0
commit eec983c8b7
2 changed files with 33 additions and 2 deletions

View File

@ -26,6 +26,7 @@ namespace fr
ConnectionFailed = 7,
HandshakeFailed = 8,
VerificationFailed = 9,
MaxPacketSizeExceeded = 10,
};
enum IP
@ -158,6 +159,22 @@ namespace fr
* @param version Should IPv4, IPv6 be used, or any?
*/
void set_inet_version(IP version);
/*!
* Sets the maximum fr::Packet size that may be received by the socket.
*
* If a client attempts to send a packet larger than sz bytes, then
* the client will be disconnected and an fr::Socket::MaxPacketSizeExceeded
* will be returned. Pass '0' to indicate no limit. The default value is 0.
*
* This should be used to prevent potential abuse, as a client could say that
* it's going to send a 200GiB packet, which would cause the Socket to try and
* allocate that much memory to accommodate the data, which is most likely not
* desirable.
*
* @param sz The maximum number of bytes that may be received in an fr::Packet
*/
void set_max_packet_size(uint32_t sz);
protected:
/*!
@ -171,6 +188,7 @@ namespace fr
std::mutex outbound_mutex;
std::mutex inbound_mutex;
int ai_family;
uint32_t max_packet_size;
#ifdef _WIN32
static WSADATA wsaData;
#endif // _WIN32

View File

@ -16,7 +16,8 @@ namespace fr
Socket::Socket() noexcept
: is_blocking(true),
ai_family(AF_UNSPEC)
ai_family(AF_UNSPEC),
max_packet_size(0)
{
if(instance_count == 0)
{
@ -68,6 +69,13 @@ namespace fr
return status;
packet_length = ntohl(packet_length);
//Check that packet_length doesn't exceed the limit, if any
if(max_packet_size && max_packet_size > max_packet_size)
{
close_socket();
return fr::Socket::MaxPacketSizeExceeded;
}
//Now we've got the length, read the rest of the data in
packet.buffer.resize(packet_length + PACKET_HEADER_LENGTH);
status = receive_all(&packet.buffer[PACKET_HEADER_LENGTH], packet_length);
@ -83,7 +91,7 @@ namespace fr
if(!connected())
return Socket::Disconnected;
int32_t bytes_remaining = buffer_size;
int32_t bytes_remaining = (int32_t) buffer_size;
size_t bytes_read = 0;
while(bytes_remaining > 0)
{
@ -128,4 +136,9 @@ namespace fr
throw std::logic_error("Unknown Socket::IP value passed to set_inet_version()");
}
}
void Socket::set_max_packet_size(uint32_t sz)
{
max_packet_size = sz;
}
}