ソースを参照

Fix HTTP 414 errors hanging until timeout (#2260)

* Fix HTTP 414 errors hanging until timeout

* All errors (status code 400+) close the connection

* 🧹

---------

Co-authored-by: Wor Ker <worker@factory>
chansikpark 3 ヶ月 前
コミット
4b2b851dbb
2 ファイル変更46 行追加3 行削除
  1. 2 3
      httplib.h
  2. 44 0
      test/test.cc

+ 2 - 3
httplib.h

@@ -7692,7 +7692,8 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
   if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); }
 
   // Prepare additional headers
-  if (close_connection || req.get_header_value("Connection") == "close") {
+  if (close_connection || req.get_header_value("Connection") == "close" ||
+      400 <= res.status) { // Don't leave connections open after errors
     res.set_header("Connection", "close");
   } else {
     std::string s = "timeout=";
@@ -8403,8 +8404,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
 
   // Check if the request URI doesn't exceed the limit
   if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) {
-    Headers dummy;
-    detail::read_headers(strm, dummy);
     res.status = StatusCode::UriTooLong_414;
     output_error_log(Error::ExceedUriMaxLength, &req);
     return write_response(strm, close_connection, req, res);

+ 44 - 0
test/test.cc

@@ -4289,10 +4289,21 @@ TEST_F(ServerTest, TooLongRequest) {
   }
   request += "_NG";
 
+  auto start = std::chrono::high_resolution_clock::now();
+
+  cli_.set_keep_alive(true);
   auto res = cli_.Get(request.c_str());
 
+  auto end = std::chrono::high_resolution_clock::now();
+  auto elapsed =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+          .count();
+
   ASSERT_TRUE(res);
   EXPECT_EQ(StatusCode::UriTooLong_414, res->status);
+  EXPECT_LE(elapsed, 100);
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+  EXPECT_FALSE(cli_.is_socket_open());
 }
 
 TEST_F(ServerTest, AlmostTooLongRequest) {
@@ -4363,10 +4374,21 @@ TEST_F(ServerTest, LongHeader) {
 }
 
 TEST_F(ServerTest, LongQueryValue) {
+  auto start = std::chrono::high_resolution_clock::now();
+
+  cli_.set_keep_alive(true);
   auto res = cli_.Get(LONG_QUERY_URL.c_str());
 
+  auto end = std::chrono::high_resolution_clock::now();
+  auto elapsed =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+          .count();
+
   ASSERT_TRUE(res);
   EXPECT_EQ(StatusCode::UriTooLong_414, res->status);
+  EXPECT_LE(elapsed, 100);
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+  EXPECT_FALSE(cli_.is_socket_open());
 }
 
 TEST_F(ServerTest, TooLongQueryValue) {
@@ -4460,6 +4482,7 @@ TEST_F(ServerTest, HeaderCountExceedsLimit) {
   }
 
   // This should fail due to exceeding header count limit
+  cli_.set_keep_alive(true);
   auto res = cli_.Get("/hi", headers);
 
   // The request should either fail or return 400 Bad Request
@@ -4470,6 +4493,9 @@ TEST_F(ServerTest, HeaderCountExceedsLimit) {
     // Or the request should fail entirely
     EXPECT_FALSE(res);
   }
+
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+  EXPECT_FALSE(cli_.is_socket_open());
 }
 
 TEST_F(ServerTest, PercentEncoding) {
@@ -4524,6 +4550,7 @@ TEST_F(ServerTest, HeaderCountSecurityTest) {
   }
 
   // Try to POST with excessive headers
+  cli_.set_keep_alive(true);
   auto res = cli_.Post("/", attack_headers, "test_data", "text/plain");
 
   // Should either fail or return 400 Bad Request due to security limit
@@ -4534,6 +4561,9 @@ TEST_F(ServerTest, HeaderCountSecurityTest) {
     // Request failed, which is the expected behavior for DoS protection
     EXPECT_FALSE(res);
   }
+
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+  EXPECT_FALSE(cli_.is_socket_open());
 }
 
 TEST_F(ServerTest, MultipartFormData) {
@@ -5854,6 +5884,20 @@ TEST_F(ServerTest, TooManyRedirect) {
   EXPECT_EQ(Error::ExceedRedirectCount, res.error());
 }
 
+TEST_F(ServerTest, BadRequestLineCancelsKeepAlive) {
+  Request req;
+  req.method = "FOOBAR";
+  req.path = "/hi";
+
+  cli_.set_keep_alive(true);
+  auto res = cli_.send(req);
+
+  ASSERT_TRUE(res);
+  EXPECT_EQ(StatusCode::BadRequest_400, res->status);
+  EXPECT_EQ("close", res->get_header_value("Connection"));
+  EXPECT_FALSE(cli_.is_socket_open());
+}
+
 TEST_F(ServerTest, StartTime) { auto res = cli_.Get("/test-start-time"); }
 
 #ifdef CPPHTTPLIB_ZLIB_SUPPORT