Browse Source

Add error handling for stream read timeouts and connection closures

yhirose 2 tháng trước cách đây
mục cha
commit
dbd5ca4bf2
3 tập tin đã thay đổi với 324 bổ sung50 xóa
  1. 19 1
      README.md
  2. 81 48
      httplib.h
  3. 224 1
      test/test.cc

+ 19 - 1
README.md

@@ -722,7 +722,7 @@ httplib::SSLClient cli("localhost");
 Here is the list of errors from `Result::error()`.
 
 ```c++
-enum Error {
+enum class Error {
   Success = 0,
   Unknown,
   Connection,
@@ -739,6 +739,24 @@ enum Error {
   Compression,
   ConnectionTimeout,
   ProxyConnection,
+  ConnectionClosed,
+  Timeout,
+  ResourceExhaustion,
+  TooManyFormDataFiles,
+  ExceedMaxPayloadSize,
+  ExceedUriMaxLength,
+  ExceedMaxSocketDescriptorCount,
+  InvalidRequestLine,
+  InvalidHTTPMethod,
+  InvalidHTTPVersion,
+  InvalidHeaders,
+  MultipartParsing,
+  OpenFile,
+  Listen,
+  GetSockName,
+  UnsupportedAddressFamily,
+  HTTPParsing,
+  InvalidRangeHeader,
 };
 ```
 

+ 81 - 48
httplib.h

@@ -838,6 +838,50 @@ struct Response {
   std::string file_content_content_type_;
 };
 
+enum class Error {
+  Success = 0,
+  Unknown,
+  Connection,
+  BindIPAddress,
+  Read,
+  Write,
+  ExceedRedirectCount,
+  Canceled,
+  SSLConnection,
+  SSLLoadingCerts,
+  SSLServerVerification,
+  SSLServerHostnameVerification,
+  UnsupportedMultipartBoundaryChars,
+  Compression,
+  ConnectionTimeout,
+  ProxyConnection,
+  ConnectionClosed,
+  Timeout,
+  ResourceExhaustion,
+  TooManyFormDataFiles,
+  ExceedMaxPayloadSize,
+  ExceedUriMaxLength,
+  ExceedMaxSocketDescriptorCount,
+  InvalidRequestLine,
+  InvalidHTTPMethod,
+  InvalidHTTPVersion,
+  InvalidHeaders,
+  MultipartParsing,
+  OpenFile,
+  Listen,
+  GetSockName,
+  UnsupportedAddressFamily,
+  HTTPParsing,
+  InvalidRangeHeader,
+
+  // For internal use only
+  SSLPeerCouldBeClosed_,
+};
+
+std::string to_string(Error error);
+
+std::ostream &operator<<(std::ostream &os, const Error &obj);
+
 class Stream {
 public:
   virtual ~Stream() = default;
@@ -856,6 +900,11 @@ public:
 
   ssize_t write(const char *ptr);
   ssize_t write(const std::string &s);
+
+  Error get_error() const { return error_; }
+
+protected:
+  Error error_ = Error::Success;
 };
 
 class TaskQueue {
@@ -1292,48 +1341,6 @@ private:
       detail::write_headers;
 };
 
-enum class Error {
-  Success = 0,
-  Unknown,
-  Connection,
-  BindIPAddress,
-  Read,
-  Write,
-  ExceedRedirectCount,
-  Canceled,
-  SSLConnection,
-  SSLLoadingCerts,
-  SSLServerVerification,
-  SSLServerHostnameVerification,
-  UnsupportedMultipartBoundaryChars,
-  Compression,
-  ConnectionTimeout,
-  ProxyConnection,
-  ResourceExhaustion,
-  TooManyFormDataFiles,
-  ExceedMaxPayloadSize,
-  ExceedUriMaxLength,
-  ExceedMaxSocketDescriptorCount,
-  InvalidRequestLine,
-  InvalidHTTPMethod,
-  InvalidHTTPVersion,
-  InvalidHeaders,
-  MultipartParsing,
-  OpenFile,
-  Listen,
-  GetSockName,
-  UnsupportedAddressFamily,
-  HTTPParsing,
-  InvalidRangeHeader,
-
-  // For internal use only
-  SSLPeerCouldBeClosed_,
-};
-
-std::string to_string(Error error);
-
-std::ostream &operator<<(std::ostream &os, const Error &obj);
-
 class Result {
 public:
   Result() = default;
@@ -2437,6 +2444,8 @@ inline std::string to_string(const Error error) {
   case Error::Compression: return "Compression failed";
   case Error::ConnectionTimeout: return "Connection timed out";
   case Error::ProxyConnection: return "Proxy connection failed";
+  case Error::ConnectionClosed: return "Connection closed by server";
+  case Error::Timeout: return "Read timeout";
   case Error::ResourceExhaustion: return "Resource exhaustion";
   case Error::TooManyFormDataFiles: return "Too many form data files";
   case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size";
@@ -7273,13 +7282,15 @@ inline ssize_t detail::BodyReader::read(char *buf, size_t len) {
     auto n = stream->read(buf, to_read);
 
     if (n < 0) {
-      last_error = Error::Read;
+      last_error = stream->get_error();
+      if (last_error == Error::Success) { last_error = Error::Read; }
       eof = true;
       return n;
     }
     if (n == 0) {
       // Unexpected EOF before content_length
-      last_error = Error::Read;
+      last_error = stream->get_error();
+      if (last_error == Error::Success) { last_error = Error::Read; }
       eof = true;
       return 0;
     }
@@ -7296,7 +7307,8 @@ inline ssize_t detail::BodyReader::read(char *buf, size_t len) {
   size_t chunk_total = 0;
   auto n = chunked_decoder->read_payload(buf, len, chunk_offset, chunk_total);
   if (n < 0) {
-    last_error = Error::Read;
+    last_error = stream->get_error();
+    if (last_error == Error::Success) { last_error = Error::Read; }
     eof = true;
     return n;
   }
@@ -7387,7 +7399,10 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
     }
   }
 
-  if (!wait_readable()) { return -1; }
+  if (!wait_readable()) {
+    error_ = Error::Timeout;
+    return -1;
+  }
 
   read_buff_off_ = 0;
   read_buff_content_size_ = 0;
@@ -7396,6 +7411,11 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
     auto n = read_socket(sock_, read_buff_.data(), read_buff_size_,
                          CPPHTTPLIB_RECV_FLAGS);
     if (n <= 0) {
+      if (n == 0) {
+        error_ = Error::ConnectionClosed;
+      } else {
+        error_ = Error::Read;
+      }
       return n;
     } else if (n <= static_cast<ssize_t>(size)) {
       memcpy(ptr, read_buff_.data(), static_cast<size_t>(n));
@@ -7407,7 +7427,15 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) {
       return static_cast<ssize_t>(size);
     }
   } else {
-    return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS);
+    auto n = read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS);
+    if (n <= 0) {
+      if (n == 0) {
+        error_ = Error::ConnectionClosed;
+      } else {
+        error_ = Error::Read;
+      }
+    }
+    return n;
   }
 }
 
@@ -11435,7 +11463,9 @@ inline bool SSLSocketStream::wait_writable() const {
 
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
   if (SSL_pending(ssl_) > 0) {
-    return SSL_read(ssl_, ptr, static_cast<int>(size));
+    auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
+    if (ret == 0) { error_ = Error::ConnectionClosed; }
+    return ret;
   } else if (wait_readable()) {
     auto ret = SSL_read(ssl_, ptr, static_cast<int>(size));
     if (ret < 0) {
@@ -11460,9 +11490,12 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
         }
       }
       assert(ret < 0);
+    } else if (ret == 0) {
+      error_ = Error::ConnectionClosed;
     }
     return ret;
   } else {
+    error_ = Error::Timeout;
     return -1;
   }
 }

+ 224 - 1
test/test.cc

@@ -12463,4 +12463,227 @@ TEST_F(SSLStreamApiTest, GetAndPost) {
   auto post = httplib::stream::Post(cli, "/echo", "test", "text/plain");
   EXPECT_EQ("test", read_body(post));
 }
-#endif
+#endif
+
+// Tests for Error::Timeout and Error::ConnectionClosed error types
+// These errors are set in SocketStream/SSLSocketStream and propagated through
+// BodyReader
+
+TEST(ErrorHandlingTest, StreamReadTimeout) {
+  // Test that read timeout during streaming is detected
+  // Use a large content-length response where server delays mid-stream
+  Server svr;
+
+  svr.Get("/slow-stream", [](const Request &, Response &res) {
+    // Send a large response with delay in the middle
+    res.set_content_provider(
+        1000, // content_length
+        "text/plain", [](size_t offset, size_t /*length*/, DataSink &sink) {
+          if (offset < 100) {
+            // Send first 100 bytes immediately
+            std::string data(100, 'A');
+            sink.write(data.c_str(), data.size());
+            return true;
+          }
+          // Then delay longer than client timeout
+          std::this_thread::sleep_for(std::chrono::seconds(3));
+          std::string data(900, 'B');
+          sink.write(data.c_str(), data.size());
+          return true;
+        });
+  });
+
+  auto port = 8091;
+  std::thread t([&]() { svr.listen("localhost", port); });
+  svr.wait_until_ready();
+
+  Client cli("localhost", port);
+  cli.set_read_timeout(1, 0); // 1 second timeout
+
+  auto handle = cli.open_stream("GET", "/slow-stream");
+  ASSERT_TRUE(handle.is_valid());
+
+  char buf[256];
+  ssize_t total = 0;
+  ssize_t n;
+  bool got_error = false;
+
+  while ((n = handle.read(buf, sizeof(buf))) > 0) {
+    total += n;
+  }
+
+  if (n < 0) {
+    got_error = true;
+    // Should be timeout or read error
+    EXPECT_TRUE(handle.get_read_error() == Error::Timeout ||
+                handle.get_read_error() == Error::Read)
+        << "Actual error: " << to_string(handle.get_read_error());
+  }
+
+  // Either we got an error, or we got less data than expected
+  EXPECT_TRUE(got_error || total < 1000)
+      << "Expected timeout but got all " << total << " bytes";
+
+  svr.stop();
+  t.join();
+}
+
+TEST(ErrorHandlingTest, StreamConnectionClosed) {
+  // Test connection closed detection via BodyReader
+  Server svr;
+  std::atomic<bool> close_now{false};
+
+  svr.Get("/will-close", [&](const Request &, Response &res) {
+    res.set_content_provider(
+        10000, // Large content_length that we won't fully send
+        "text/plain", [&](size_t offset, size_t /*length*/, DataSink &sink) {
+          if (offset < 100) {
+            std::string data(100, 'X');
+            sink.write(data.c_str(), data.size());
+            return true;
+          }
+          // Wait for signal then abort
+          while (!close_now) {
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+          }
+          return false; // Abort - server will close connection
+        });
+  });
+
+  auto port = 8092;
+  std::thread t([&]() { svr.listen("localhost", port); });
+  svr.wait_until_ready();
+
+  Client cli("localhost", port);
+  auto handle = cli.open_stream("GET", "/will-close");
+  ASSERT_TRUE(handle.is_valid());
+
+  char buf[256];
+  ssize_t n = handle.read(buf, sizeof(buf)); // First read
+  EXPECT_GT(n, 0) << "First read should succeed";
+
+  // Signal server to close
+  close_now = true;
+
+  // Keep reading until error or EOF
+  while ((n = handle.read(buf, sizeof(buf))) > 0) {
+    // Keep reading
+  }
+
+  // Should get an error since content_length wasn't satisfied
+  if (n < 0) {
+    EXPECT_TRUE(handle.get_read_error() == Error::ConnectionClosed ||
+                handle.get_read_error() == Error::Read)
+        << "Actual error: " << to_string(handle.get_read_error());
+  }
+
+  svr.stop();
+  t.join();
+}
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+TEST(ErrorHandlingTest, SSLStreamReadTimeout) {
+  // Test that read timeout during SSL streaming is detected
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
+
+  svr.Get("/slow-stream", [](const Request &, Response &res) {
+    res.set_content_provider(
+        1000, "text/plain",
+        [](size_t offset, size_t /*length*/, DataSink &sink) {
+          if (offset < 100) {
+            std::string data(100, 'A');
+            sink.write(data.c_str(), data.size());
+            return true;
+          }
+          std::this_thread::sleep_for(std::chrono::seconds(3));
+          std::string data(900, 'B');
+          sink.write(data.c_str(), data.size());
+          return true;
+        });
+  });
+
+  auto port = 8093;
+  std::thread t([&]() { svr.listen("localhost", port); });
+  svr.wait_until_ready();
+
+  SSLClient cli("localhost", port);
+  cli.enable_server_certificate_verification(false);
+  cli.set_read_timeout(1, 0); // 1 second timeout
+
+  auto handle = cli.open_stream("GET", "/slow-stream");
+  ASSERT_TRUE(handle.is_valid());
+
+  char buf[256];
+  ssize_t total = 0;
+  ssize_t n;
+  bool got_error = false;
+
+  while ((n = handle.read(buf, sizeof(buf))) > 0) {
+    total += n;
+  }
+
+  if (n < 0) {
+    got_error = true;
+    EXPECT_TRUE(handle.get_read_error() == Error::Timeout ||
+                handle.get_read_error() == Error::Read)
+        << "Actual error: " << to_string(handle.get_read_error());
+  }
+
+  EXPECT_TRUE(got_error || total < 1000)
+      << "Expected timeout but got all " << total << " bytes";
+
+  svr.stop();
+  t.join();
+}
+
+TEST(ErrorHandlingTest, SSLStreamConnectionClosed) {
+  // Test SSL connection closed detection
+  SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
+  std::atomic<bool> close_now{false};
+
+  svr.Get("/will-close", [&](const Request &, Response &res) {
+    res.set_content_provider(
+        10000, "text/plain",
+        [&](size_t offset, size_t /*length*/, DataSink &sink) {
+          if (offset < 100) {
+            std::string data(100, 'X');
+            sink.write(data.c_str(), data.size());
+            return true;
+          }
+          while (!close_now) {
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+          }
+          return false;
+        });
+  });
+
+  auto port = 8094;
+  std::thread t([&]() { svr.listen("localhost", port); });
+  svr.wait_until_ready();
+
+  SSLClient cli("localhost", port);
+  cli.enable_server_certificate_verification(false);
+  auto handle = cli.open_stream("GET", "/will-close");
+  ASSERT_TRUE(handle.is_valid());
+
+  char buf[256];
+  ssize_t n = handle.read(buf, sizeof(buf)); // First read
+  EXPECT_GT(n, 0);
+
+  // Signal server to close
+  close_now = true;
+
+  while ((n = handle.read(buf, sizeof(buf))) > 0) {
+    // Keep reading
+  }
+
+  if (n < 0) {
+    EXPECT_TRUE(handle.get_read_error() == Error::ConnectionClosed ||
+                handle.get_read_error() == Error::Read)
+        << "Actual error: " << to_string(handle.get_read_error());
+  }
+
+  svr.stop();
+  t.join();
+}
+#endif