Explorar o código

Fix #1601 (#2167)

* Fix #1601

* clang-format

* Fix Windows problem

* Use GetAddrInfoEx on Windows

* Fix Windows problem

* Add getaddrinfo_a

* clang-format

* Adjust Benchmark Test

* Test

* Fix Bench test

* Fix build error

* Fix build error

* Fix Makefile

* Fix build error

* Fix buid error
yhirose hai 7 meses
pai
achega
ea850cbfa7
Modificáronse 4 ficheiros con 356 adicións e 31 borrados
  1. 337 9
      httplib.h
  2. 11 3
      test/Makefile
  3. 1 1
      test/fuzzing/Makefile
  4. 7 18
      test/test.cc

+ 337 - 9
httplib.h

@@ -278,6 +278,14 @@ using socket_t = int;
 #include <unordered_set>
 #include <utility>
 
+#if defined(__APPLE__)
+#include <TargetConditionals.h>
+#if TARGET_OS_OSX || TARGET_OS_IPHONE
+#include <CFNetwork/CFHost.h>
+#include <CoreFoundation/CoreFoundation.h>
+#endif
+#endif
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef _WIN32
 #include <wincrypt.h>
@@ -292,13 +300,16 @@ using socket_t = int;
 #ifdef _MSC_VER
 #pragma comment(lib, "crypt32.lib")
 #endif
-#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__)
+#endif // _WIN32
+
+#if defined(__APPLE__)
+#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
 #include <TargetConditionals.h>
 #if TARGET_OS_OSX
-#include <CoreFoundation/CoreFoundation.h>
 #include <Security/Security.h>
 #endif // TARGET_OS_OSX
-#endif // _WIN32
+#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
+#endif // ___APPLE__
 
 #include <openssl/err.h>
 #include <openssl/evp.h>
@@ -321,7 +332,7 @@ using socket_t = int;
 #error Sorry, OpenSSL versions prior to 3.0.0 are not supported
 #endif
 
-#endif
+#endif // CPPHTTPLIB_OPENSSL_SUPPORT
 
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT
 #include <zlib.h>
@@ -3369,11 +3380,323 @@ unescape_abstract_namespace_unix_domain(const std::string &s) {
   return s;
 }
 
+inline int getaddrinfo_with_timeout(const char *node, const char *service,
+                                    const struct addrinfo *hints,
+                                    struct addrinfo **res, time_t timeout_sec) {
+  if (timeout_sec <= 0) {
+    // No timeout specified, use standard getaddrinfo
+    return getaddrinfo(node, service, hints, res);
+  }
+
+#ifdef _WIN32
+  // Windows-specific implementation using GetAddrInfoEx with overlapped I/O
+  OVERLAPPED overlapped = {0};
+  HANDLE event = CreateEventW(nullptr, TRUE, FALSE, nullptr);
+  if (!event) { return EAI_FAIL; }
+
+  overlapped.hEvent = event;
+
+  PADDRINFOEXW result_addrinfo = nullptr;
+  HANDLE cancel_handle = nullptr;
+
+  ADDRINFOEXW hints_ex = {0};
+  if (hints) {
+    hints_ex.ai_flags = hints->ai_flags;
+    hints_ex.ai_family = hints->ai_family;
+    hints_ex.ai_socktype = hints->ai_socktype;
+    hints_ex.ai_protocol = hints->ai_protocol;
+  }
+
+  auto wnode = u8string_to_wstring(node);
+  auto wservice = u8string_to_wstring(service);
+
+  auto ret = ::GetAddrInfoExW(wnode.data(), wservice.data(), NS_DNS, nullptr,
+                              hints ? &hints_ex : nullptr, &result_addrinfo,
+                              nullptr, &overlapped, nullptr, &cancel_handle);
+
+  if (ret == WSA_IO_PENDING) {
+    auto wait_result =
+        ::WaitForSingleObject(event, static_cast<DWORD>(timeout_sec * 1000));
+    if (wait_result == WAIT_TIMEOUT) {
+      if (cancel_handle) { ::GetAddrInfoExCancel(&cancel_handle); }
+      ::CloseHandle(event);
+      return EAI_AGAIN;
+    }
+
+    DWORD bytes_returned;
+    if (!::GetOverlappedResult((HANDLE)INVALID_SOCKET, &overlapped,
+                               &bytes_returned, FALSE)) {
+      ::CloseHandle(event);
+      return ::WSAGetLastError();
+    }
+  }
+
+  ::CloseHandle(event);
+
+  if (ret == NO_ERROR || ret == WSA_IO_PENDING) {
+    *res = reinterpret_cast<struct addrinfo *>(result_addrinfo);
+    return 0;
+  }
+
+  return ret;
+#elif defined(__APPLE__)
+  // macOS implementation using CFHost API for asynchronous DNS resolution
+  CFStringRef hostname_ref = CFStringCreateWithCString(
+      kCFAllocatorDefault, node, kCFStringEncodingUTF8);
+  if (!hostname_ref) { return EAI_MEMORY; }
+
+  CFHostRef host_ref = CFHostCreateWithName(kCFAllocatorDefault, hostname_ref);
+  CFRelease(hostname_ref);
+  if (!host_ref) { return EAI_MEMORY; }
+
+  // Set up context for callback
+  struct CFHostContext {
+    bool completed = false;
+    bool success = false;
+    CFArrayRef addresses = nullptr;
+    std::mutex mutex;
+    std::condition_variable cv;
+  } context;
+
+  CFHostClientContext client_context;
+  memset(&client_context, 0, sizeof(client_context));
+  client_context.info = &context;
+
+  // Set callback
+  auto callback = [](CFHostRef theHost, CFHostInfoType /*typeInfo*/,
+                     const CFStreamError *error, void *info) {
+    auto ctx = static_cast<CFHostContext *>(info);
+    std::lock_guard<std::mutex> lock(ctx->mutex);
+
+    if (error && error->error != 0) {
+      ctx->success = false;
+    } else {
+      Boolean hasBeenResolved;
+      ctx->addresses = CFHostGetAddressing(theHost, &hasBeenResolved);
+      if (ctx->addresses && hasBeenResolved) {
+        CFRetain(ctx->addresses);
+        ctx->success = true;
+      } else {
+        ctx->success = false;
+      }
+    }
+    ctx->completed = true;
+    ctx->cv.notify_one();
+  };
+
+  if (!CFHostSetClient(host_ref, callback, &client_context)) {
+    CFRelease(host_ref);
+    return EAI_SYSTEM;
+  }
+
+  // Schedule on run loop
+  CFRunLoopRef run_loop = CFRunLoopGetCurrent();
+  CFHostScheduleWithRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);
+
+  // Start resolution
+  CFStreamError stream_error;
+  if (!CFHostStartInfoResolution(host_ref, kCFHostAddresses, &stream_error)) {
+    CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);
+    CFRelease(host_ref);
+    return EAI_FAIL;
+  }
+
+  // Wait for completion with timeout
+  auto timeout_time =
+      std::chrono::steady_clock::now() + std::chrono::seconds(timeout_sec);
+  bool timed_out = false;
+
+  {
+    std::unique_lock<std::mutex> lock(context.mutex);
+
+    while (!context.completed) {
+      auto now = std::chrono::steady_clock::now();
+      if (now >= timeout_time) {
+        timed_out = true;
+        break;
+      }
+
+      // Run the runloop for a short time
+      lock.unlock();
+      CFRunLoopRunInMode(kCFRunLoopDefaultMode, 0.1, true);
+      lock.lock();
+    }
+  }
+
+  // Clean up
+  CFHostUnscheduleFromRunLoop(host_ref, run_loop, kCFRunLoopDefaultMode);
+  CFHostSetClient(host_ref, nullptr, nullptr);
+
+  if (timed_out || !context.completed) {
+    CFHostCancelInfoResolution(host_ref, kCFHostAddresses);
+    CFRelease(host_ref);
+    return EAI_AGAIN;
+  }
+
+  if (!context.success || !context.addresses) {
+    CFRelease(host_ref);
+    return EAI_NODATA;
+  }
+
+  // Convert CFArray to addrinfo
+  CFIndex count = CFArrayGetCount(context.addresses);
+  if (count == 0) {
+    CFRelease(context.addresses);
+    CFRelease(host_ref);
+    return EAI_NODATA;
+  }
+
+  struct addrinfo *result_addrinfo = nullptr;
+  struct addrinfo **current = &result_addrinfo;
+
+  for (CFIndex i = 0; i < count; i++) {
+    CFDataRef addr_data =
+        static_cast<CFDataRef>(CFArrayGetValueAtIndex(context.addresses, i));
+    if (!addr_data) continue;
+
+    const struct sockaddr *sockaddr_ptr =
+        reinterpret_cast<const struct sockaddr *>(CFDataGetBytePtr(addr_data));
+    socklen_t sockaddr_len = static_cast<socklen_t>(CFDataGetLength(addr_data));
+
+    // Allocate addrinfo structure
+    *current = static_cast<struct addrinfo *>(malloc(sizeof(struct addrinfo)));
+    if (!*current) {
+      freeaddrinfo(result_addrinfo);
+      CFRelease(context.addresses);
+      CFRelease(host_ref);
+      return EAI_MEMORY;
+    }
+
+    memset(*current, 0, sizeof(struct addrinfo));
+
+    // Set up addrinfo fields
+    (*current)->ai_family = sockaddr_ptr->sa_family;
+    (*current)->ai_socktype = hints ? hints->ai_socktype : SOCK_STREAM;
+    (*current)->ai_protocol = hints ? hints->ai_protocol : IPPROTO_TCP;
+    (*current)->ai_addrlen = sockaddr_len;
+
+    // Copy sockaddr
+    (*current)->ai_addr = static_cast<struct sockaddr *>(malloc(sockaddr_len));
+    if (!(*current)->ai_addr) {
+      freeaddrinfo(result_addrinfo);
+      CFRelease(context.addresses);
+      CFRelease(host_ref);
+      return EAI_MEMORY;
+    }
+    memcpy((*current)->ai_addr, sockaddr_ptr, sockaddr_len);
+
+    // Set port if service is specified
+    if (service && strlen(service) > 0) {
+      int port = atoi(service);
+      if (port > 0) {
+        if (sockaddr_ptr->sa_family == AF_INET) {
+          reinterpret_cast<struct sockaddr_in *>((*current)->ai_addr)
+              ->sin_port = htons(static_cast<uint16_t>(port));
+        } else if (sockaddr_ptr->sa_family == AF_INET6) {
+          reinterpret_cast<struct sockaddr_in6 *>((*current)->ai_addr)
+              ->sin6_port = htons(static_cast<uint16_t>(port));
+        }
+      }
+    }
+
+    current = &((*current)->ai_next);
+  }
+
+  CFRelease(context.addresses);
+  CFRelease(host_ref);
+
+  *res = result_addrinfo;
+  return 0;
+#elif defined(_GNU_SOURCE) && defined(__GLIBC__) &&                            \
+    (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2))
+  // Linux implementation using getaddrinfo_a for asynchronous DNS resolution
+  struct gaicb request;
+  struct gaicb *requests[1] = {&request};
+  struct sigevent sevp;
+  struct timespec timeout;
+
+  // Initialize the request structure
+  memset(&request, 0, sizeof(request));
+  request.ar_name = node;
+  request.ar_service = service;
+  request.ar_request = hints;
+
+  // Set up timeout
+  timeout.tv_sec = timeout_sec;
+  timeout.tv_nsec = 0;
+
+  // Initialize sigevent structure (not used, but required)
+  memset(&sevp, 0, sizeof(sevp));
+  sevp.sigev_notify = SIGEV_NONE;
+
+  // Start asynchronous resolution
+  int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp);
+  if (start_result != 0) { return start_result; }
+
+  // Wait for completion with timeout
+  int wait_result =
+      gai_suspend((const struct gaicb *const *)requests, 1, &timeout);
+
+  if (wait_result == 0) {
+    // Completed successfully, get the result
+    int gai_result = gai_error(&request);
+    if (gai_result == 0) {
+      *res = request.ar_result;
+      return 0;
+    } else {
+      // Clean up on error
+      if (request.ar_result) { freeaddrinfo(request.ar_result); }
+      return gai_result;
+    }
+  } else if (wait_result == EAI_AGAIN) {
+    // Timeout occurred, cancel the request
+    gai_cancel(&request);
+    return EAI_AGAIN;
+  } else {
+    // Other error occurred
+    gai_cancel(&request);
+    return wait_result;
+  }
+#else
+  // Fallback implementation using thread-based timeout for other Unix systems
+  std::mutex result_mutex;
+  std::condition_variable result_cv;
+  auto completed = false;
+  auto result = EAI_SYSTEM;
+  struct addrinfo *result_addrinfo = nullptr;
+
+  std::thread resolve_thread([&]() {
+    auto thread_result = getaddrinfo(node, service, hints, &result_addrinfo);
+
+    std::lock_guard<std::mutex> lock(result_mutex);
+    result = thread_result;
+    completed = true;
+    result_cv.notify_one();
+  });
+
+  // Wait for completion or timeout
+  std::unique_lock<std::mutex> lock(result_mutex);
+  auto finished = result_cv.wait_for(lock, std::chrono::seconds(timeout_sec),
+                                     [&] { return completed; });
+
+  if (finished) {
+    // Operation completed within timeout
+    resolve_thread.join();
+    *res = result_addrinfo;
+    return result;
+  } else {
+    // Timeout occurred
+    resolve_thread.detach(); // Let the thread finish in background
+    return EAI_AGAIN;        // Return timeout error
+  }
+#endif
+}
+
 template <typename BindOrConnect>
 socket_t create_socket(const std::string &host, const std::string &ip, int port,
                        int address_family, int socket_flags, bool tcp_nodelay,
                        bool ipv6_v6only, SocketOptions socket_options,
-                       BindOrConnect bind_or_connect) {
+                       BindOrConnect bind_or_connect, time_t timeout_sec = 0) {
   // Get address info
   const char *node = nullptr;
   struct addrinfo hints;
@@ -3443,7 +3766,8 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port,
 
   auto service = std::to_string(port);
 
-  if (getaddrinfo(node, service.c_str(), &hints, &result)) {
+  if (getaddrinfo_with_timeout(node, service.c_str(), &hints, &result,
+                               timeout_sec)) {
 #if defined __linux__ && !defined __ANDROID__
     res_init();
 #endif
@@ -3541,7 +3865,9 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) {
   hints.ai_socktype = SOCK_STREAM;
   hints.ai_protocol = 0;
 
-  if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; }
+  if (getaddrinfo_with_timeout(host.c_str(), "0", &hints, &result, 0)) {
+    return false;
+  }
   auto se = detail::scope_exit([&] { freeaddrinfo(result); });
 
   auto ret = false;
@@ -3646,7 +3972,8 @@ inline socket_t create_client_socket(
 
         error = Error::Success;
         return true;
-      });
+      },
+      connection_timeout_sec); // Pass DNS timeout
 
   if (sock != INVALID_SOCKET) {
     error = Error::Success;
@@ -5867,7 +6194,8 @@ inline void hosted_at(const std::string &hostname,
   hints.ai_socktype = SOCK_STREAM;
   hints.ai_protocol = 0;
 
-  if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) {
+  if (detail::getaddrinfo_with_timeout(hostname.c_str(), nullptr, &hints,
+                                       &result, 0)) {
 #if defined __linux__ && !defined __ANDROID__
     res_init();
 #endif

+ 11 - 3
test/Makefile

@@ -9,7 +9,7 @@ OPENSSL_SUPPORT = -DCPPHTTPLIB_OPENSSL_SUPPORT -I$(OPENSSL_DIR)/include -L$(OPEN
 ifneq ($(OS), Windows_NT)
 	UNAME_S := $(shell uname -s)
 	ifeq ($(UNAME_S), Darwin)
-		OPENSSL_SUPPORT += -DCPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN -framework CoreFoundation -framework Security
+		OPENSSL_SUPPORT += -DCPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN -framework CoreFoundation -framework Security -framework CFNetwork
 	endif
 endif
 
@@ -21,7 +21,15 @@ BROTLI_SUPPORT = -DCPPHTTPLIB_BROTLI_SUPPORT -I$(BROTLI_DIR)/include -L$(BROTLI_
 ZSTD_DIR = $(PREFIX)/opt/zstd
 ZSTD_SUPPORT = -DCPPHTTPLIB_ZSTD_SUPPORT -I$(ZSTD_DIR)/include -L$(ZSTD_DIR)/lib -lzstd
 
-TEST_ARGS = gtest/src/gtest-all.cc gtest/src/gtest_main.cc -Igtest -Igtest/include $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) $(ZSTD_SUPPORT) -pthread -lcurl
+LIBS = -lpthread -lcurl
+ifneq ($(OS), Windows_NT)
+	UNAME_S := $(shell uname -s)
+	ifneq ($(UNAME_S), Darwin)
+		LIBS += -lanl
+	endif
+endif
+
+TEST_ARGS = gtest/src/gtest-all.cc gtest/src/gtest_main.cc -Igtest -Igtest/include $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) $(ZSTD_SUPPORT) $(LIBS)
 
 # By default, use standalone_fuzz_target_runner.
 # This runner does no fuzzing, but simply executes the inputs
@@ -86,7 +94,7 @@ fuzz_test: server_fuzzer
 
 # Fuzz target, so that you can choose which $(LIB_FUZZING_ENGINE) to use.
 server_fuzzer : fuzzing/server_fuzzer.cc ../httplib.h standalone_fuzz_target_runner.o
-	$(CXX) -o $@ -I.. $(CXXFLAGS) $< $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) $(LIB_FUZZING_ENGINE) -pthread
+	$(CXX) -o $@ -I.. $(CXXFLAGS) $< $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT) $(LIB_FUZZING_ENGINE) $(ZSTD_SUPPORT) $(LIBS)
 	@file $@
 
 # Standalone fuzz runner, which just reads inputs from fuzzing/corpus/ dir and

+ 1 - 1
test/fuzzing/Makefile

@@ -20,7 +20,7 @@ all : server_fuzzer
 # Fuzz target, so that you can choose which $(LIB_FUZZING_ENGINE) to use.
 server_fuzzer : server_fuzzer.cc ../../httplib.h
 # 	$(CXX) $(CXXFLAGS) -o $@  $<  -Wl,-Bstatic $(OPENSSL_SUPPORT)  -Wl,-Bdynamic -ldl  $(ZLIB_SUPPORT)  $(LIB_FUZZING_ENGINE) -pthread
-	$(CXX) $(CXXFLAGS) -o $@  $<  $(ZLIB_SUPPORT)  $(LIB_FUZZING_ENGINE) -pthread
+	$(CXX) $(CXXFLAGS) -o $@  $<  $(ZLIB_SUPPORT)  $(LIB_FUZZING_ENGINE) -pthread -lanl
 	zip -q -r server_fuzzer_seed_corpus.zip corpus
 
 clean:

+ 7 - 18
test/test.cc

@@ -3388,31 +3388,20 @@ void performance_test(const char *host) {
 
   Client cli(host, port);
 
-  const int NUM_REQUESTS = 50;
-  const int MAX_AVERAGE_MS = 5;
+  auto start = std::chrono::high_resolution_clock::now();
 
-  auto warmup = cli.Get("/benchmark");
-  ASSERT_TRUE(warmup);
+  auto res = cli.Get("/benchmark");
+  ASSERT_TRUE(res);
+  EXPECT_EQ(StatusCode::OK_200, res->status);
 
-  auto start = std::chrono::high_resolution_clock::now();
-  for (int i = 0; i < NUM_REQUESTS; ++i) {
-    auto res = cli.Get("/benchmark");
-    ASSERT_TRUE(res) << "Request " << i << " failed";
-    EXPECT_EQ(StatusCode::OK_200, res->status);
-  }
   auto end = std::chrono::high_resolution_clock::now();
 
-  auto total_ms =
+  auto elapsed =
       std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
           .count();
-  double avg_ms = static_cast<double>(total_ms) / NUM_REQUESTS;
-
-  std::cout << "Peformance test at \"" << host << "\": " << NUM_REQUESTS
-            << " requests in " << total_ms << "ms (avg: " << avg_ms << "ms)"
-            << std::endl;
 
-  EXPECT_LE(avg_ms, MAX_AVERAGE_MS)
-      << "Performance is too slow: " << avg_ms << "ms (Issue #1777)";
+  EXPECT_LE(elapsed, 5) << "Performance is too slow: " << elapsed
+                        << "ms (Issue #1777)";
 }
 
 TEST(BenchmarkTest, localhost) { performance_test("localhost"); }