diff --git a/CMakeLists.txt b/CMakeLists.txt index a5a8d53..a326554 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ if(USE_SSL) FIND_PACKAGE(MBEDTLS) INCLUDE_DIRECTORIES(${MBEDTLS_INCLUDE_DIR}) set(SOURCE_FILES ${SOURCE_FILES} src/SSLSocket.cpp include/frnetlib/SSLSocket.h src/SSLListener.cpp include/frnetlib/SSLListener.h include/frnetlib/SSLContext.h) + ADD_DEFINITIONS(-DUSE_SSL) endif() if(BUILD_WEBSOCK) diff --git a/include/frnetlib/Socket.h b/include/frnetlib/Socket.h index 0390bb2..85a4d41 100644 --- a/include/frnetlib/Socket.h +++ b/include/frnetlib/Socket.h @@ -34,6 +34,11 @@ namespace fr ParseError = 12, HttpHeaderTooBig = 13, HttpBodyTooBig = 14, + AddressLookupFailure = 15, + SendError = 16, + ReceiveError = 17, + AcceptError = 18, + SSLError = 19, //Remember to update status_to_string if more are added }; @@ -51,6 +56,17 @@ namespace fr void operator=(const Socket &) =delete; void operator=(Socket &&) =delete; + /*! + * Converts an fr::Socket::Status value to a printable string + * + * Throws an std::logic_error if status is out of range. + * + * @note This should be called immediately after the error, as errno is used to help generate the string. + * @param status Status value to convert + * @return A string form version + */ + static std::string status_to_string(fr::Socket::Status status); + /*! * Connects the socket to an address. * @@ -155,16 +171,6 @@ namespace fr */ void set_inet_version(IP version); - /*! - * Converts an fr::Socket::Status value to a printable string - * - * Throws an std::logic_error if status is out of range. - * - * @param status Status value to convert - * @return A string form version - */ - static const std::string &status_to_string(fr::Socket::Status status); - /*! * Ends, and closes the connection. * There is a distinction between 'disconnect' and 'close_socket', diff --git a/include/frnetlib/WebSocket.h b/include/frnetlib/WebSocket.h index 7aa2bde..ccddc5b 100644 --- a/include/frnetlib/WebSocket.h +++ b/include/frnetlib/WebSocket.h @@ -69,6 +69,7 @@ namespace fr if(response.get_status() != Http::SwitchingProtocols) { disconnect(); + errno = EPROTO; return Socket::HandshakeFailed; } @@ -77,6 +78,7 @@ namespace fr if(derived_key != Base64::encode(Sha1::sha1_digest(websocket_key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))) { disconnect(); + errno = EPROTO; return Socket::HandshakeFailed; } diff --git a/src/SSLListener.cpp b/src/SSLListener.cpp index d46cd9f..f76fe4e 100644 --- a/src/SSLListener.cpp +++ b/src/SSLListener.cpp @@ -109,7 +109,7 @@ namespace fr if((error = mbedtls_net_accept(&listen_fd, client_fd.get(), client_ip, sizeof(client_ip), &ip_len)) != 0) { free_contexts(); - return Socket::Error; + return Socket::AcceptError; } @@ -120,7 +120,8 @@ namespace fr if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { free_contexts(); - return Socket::Status::HandshakeFailed; + errno = error; + return Socket::Status::SSLError; } } diff --git a/src/SSLSocket.cpp b/src/SSLSocket.cpp index 560dd4b..e9c8694 100644 --- a/src/SSLSocket.cpp +++ b/src/SSLSocket.cpp @@ -56,7 +56,8 @@ namespace fr } else if(response < 0) { - return Socket::Status::Error; + errno = response; + return Socket::Status::SSLError; } } @@ -76,7 +77,8 @@ namespace fr return Socket::Status::WouldBlock; } - return Socket::Status::Error; + errno = static_cast(status); + return Socket::Status::SSLError; } } else @@ -95,7 +97,8 @@ namespace fr continue; //try again, interrupted before anything could be received } - return Socket::Status::Error; + errno = static_cast(status); + return Socket::Status::SSLError; } break; } while(true); @@ -131,7 +134,8 @@ namespace fr int error = 0; if((error = mbedtls_ssl_config_defaults(&conf, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) { - return Socket::Status::Error; + errno = error; + return Socket::Status::SSLError; } mbedtls_ssl_conf_authmode(&conf, should_verify ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE); @@ -140,12 +144,14 @@ namespace fr if((error = mbedtls_ssl_setup(ssl.get(), &conf)) != 0) { - return Socket::Status::Error; + errno = error; + return Socket::Status::SSLError; } if((error = mbedtls_ssl_set_hostname(ssl.get(), address.c_str())) != 0) { - return Socket::Status::Error; + errno = error; + return Socket::Status::SSLError; } mbedtls_ssl_set_bio(ssl.get(), ssl_socket_descriptor.get(), mbedtls_net_send, mbedtls_net_recv, nullptr); @@ -155,16 +161,14 @@ namespace fr { if(error != MBEDTLS_ERR_SSL_WANT_READ && error != MBEDTLS_ERR_SSL_WANT_WRITE) { - return Socket::Status::HandshakeFailed; + errno = error; + return Socket::Status::SSLError; } } //Verify server certificate if(should_verify && ((flags = mbedtls_ssl_get_verify_result(ssl.get())) != 0)) { - char verify_buffer[512]; - mbedtls_x509_crt_verify_info(verify_buffer, sizeof(verify_buffer), " ! ", flags); - return Socket::Status::VerificationFailed; } diff --git a/src/Socket.cpp b/src/Socket.cpp index 491615f..9b93450 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -6,6 +6,9 @@ #include #include #include +#ifdef USE_SSL +#include +#endif #include "frnetlib/NetworkEncoding.h" #include "frnetlib/Socket.h" #include "frnetlib/Sendable.h" @@ -83,29 +86,81 @@ namespace fr } } - const std::string &Socket::status_to_string(fr::Socket::Status status) + std::string Socket::status_to_string(fr::Socket::Status status) { - static std::vector map = { - "Unknown", - "Success", - "Listen Failed", - "Bind Failed", - "Disconnected", - "Error", - "Would Block", - "Connection Failed", - "Handshake Failed", - "Verification Failed", - "Max packet size exceeded", - "Not enough data", - "Parse error", - "HTTP header too big", - "HTTP body too big" +#ifdef _WIN32 + auto wsa_err_to_str = [](int err) -> std::string { + std::string buff(255, '\0'); + auto len = FormatMessage (FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID (LANG_NEUTRAL, SUBLANG_DEFAULT), msgbuf, buff.size(), NULL); + if(len == 0) + return "Unknown"; + buff.resize(len); + return buff; }; + #define ERR_STR wsa_err_to_str(WSAGetLastError()) +#else + #define ERR_STR strerror(errno) +#endif - if(status < 0 || status > map.size()) - throw std::logic_error("Socket::status_to_string(): Invalid status value " + std::to_string(status)); - return map[status]; + switch(status) + { + case Unknown: + return "Unknown"; + case Success: + return "Success"; + case ListenFailed: + return std::string("Listen Failed (").append(ERR_STR).append(")"); + case BindFailed: + return std::string("Bind Failed (").append(ERR_STR).append(")"); + case Disconnected: + return "The Socket Is Not Connected"; + case Error: + return "Error"; + case WouldBlock: + return "Would Block"; + case ConnectionFailed: + return "Connection Failed"; + case HandshakeFailed: + return "Handshake Failed"; + case VerificationFailed: + return "Verification Failed"; + case MaxPacketSizeExceeded: + return "Max Packet Size Exceeded"; + case NotEnoughData: + return "Not Enough Data"; + case ParseError: + return "Parse Error"; + case HttpHeaderTooBig: + return "HTTP Header Too Big"; + case HttpBodyTooBig: + return "HTTP Body Too Big"; + case AddressLookupFailure: +#ifdef _WIN32 + return std::string("Address Lookup Failure (").append(wsa_err_to_str(WSAGetLastError())).append(")"); +#else + return std::string("Address Lookup Failure (").append(gai_strerror(errno)).append(")"); +#endif + case SendError: + return std::string("Send Error (").append(ERR_STR).append(")"); + case ReceiveError: + return std::string("Receive Error (").append(ERR_STR).append(")"); + case AcceptError: + return std::string("Accept Error (").append(ERR_STR).append(")"); + case SSLError: + { +#ifdef USE_SSL + char buff[256] = {0}; + mbedtls_strerror(errno, buff, sizeof(buff)); + return std::string("SSL Error (").append(buff).append(")"); +#else + return "Generic SSL Error"; +#endif + } + default: + return "Unknown"; + } + + return "Internal Error"; } void Socket::disconnect() diff --git a/src/TcpListener.cpp b/src/TcpListener.cpp index 3573694..51d3aa1 100644 --- a/src/TcpListener.cpp +++ b/src/TcpListener.cpp @@ -36,8 +36,9 @@ namespace fr if(getaddrinfo(nullptr, port.c_str(), &hints, &info) != 0) { - return Socket::Status::Unknown; + return Socket::Status::AddressLookupFailure; } + //Try each of the results until we listen successfully addrinfo *c = nullptr; for(c = info; c != nullptr; c = c->ai_next) @@ -107,7 +108,7 @@ namespace fr socklen_t client_addr_len = sizeof client_addr; client_descriptor = ::accept(socket_descriptor, (sockaddr*)&client_addr, &client_addr_len); if(client_descriptor == SOCKET_ERROR) - return Socket::Unknown; + return Socket::AcceptError; //Get printable address. If we failed then set it as just 'unknown' int err = getnameinfo((sockaddr*)&client_addr, client_addr_len, client_printable_addr, sizeof(client_printable_addr), nullptr, 0, NI_NUMERICHOST); diff --git a/src/TcpSocket.cpp b/src/TcpSocket.cpp index 3e3befc..7a04f6a 100644 --- a/src/TcpSocket.cpp +++ b/src/TcpSocket.cpp @@ -35,7 +35,7 @@ namespace fr } else if(errno != EWOULDBLOCK && errno != EAGAIN) //Don't exit if the socket just couldn't block { - return Socket::Status::Error; + return Socket::Status::SendError; } } return Socket::Status::Success; @@ -67,7 +67,7 @@ namespace fr continue; //try again, interrupted before anything could be received } - return Socket::Status::Error; + return Socket::Status::ReceiveError; } break; } while(true); @@ -104,9 +104,10 @@ namespace fr hints.ai_flags = AI_PASSIVE; //Have the IP filled in for us //Query remote address information - if(getaddrinfo(address.c_str(), port.c_str(), &hints, &info) != 0) + if((ret = getaddrinfo(address.c_str(), port.c_str(), &hints, &info)) != 0) { - return Socket::Status::Error; + errno = ret; + return Socket::Status::AddressLookupFailure; } //Try to connect to results returned by getaddrinfo until we succeed/run out of things