Browse Source

Merge commit from fork

yhirose 1 month ago
parent
commit
c99d7472b5
2 changed files with 124 additions and 30 deletions
  1. 61 29
      httplib.h
  2. 63 1
      test/test.cc

+ 61 - 29
httplib.h

@@ -6950,7 +6950,8 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) {
 template <typename T, typename U>
 bool prepare_content_receiver(T &x, int &status,
                               ContentReceiverWithProgress receiver,
-                              bool decompress, U callback) {
+                              bool decompress, size_t payload_max_length,
+                              bool &exceed_payload_max_length, U callback) {
   if (decompress) {
     std::string encoding = x.get_header_value("Content-Encoding");
     std::unique_ptr<decompressor> decompressor;
@@ -6966,12 +6967,22 @@ bool prepare_content_receiver(T &x, int &status,
 
     if (decompressor) {
       if (decompressor->is_valid()) {
+        size_t decompressed_size = 0;
         ContentReceiverWithProgress out = [&](const char *buf, size_t n,
                                               size_t off, size_t len) {
-          return decompressor->decompress(buf, n,
-                                          [&](const char *buf2, size_t n2) {
-                                            return receiver(buf2, n2, off, len);
-                                          });
+          return decompressor->decompress(
+              buf, n, [&](const char *buf2, size_t n2) {
+                // Guard against zip-bomb: check
+                // decompressed size against limit.
+                if (payload_max_length > 0 &&
+                    (decompressed_size >= payload_max_length ||
+                     n2 > payload_max_length - decompressed_size)) {
+                  exceed_payload_max_length = true;
+                  return false;
+                }
+                decompressed_size += n2;
+                return receiver(buf2, n2, off, len);
+              });
         };
         return callback(std::move(out));
       } else {
@@ -6992,11 +7003,14 @@ template <typename T>
 bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
                   DownloadProgress progress,
                   ContentReceiverWithProgress receiver, bool decompress) {
+  bool exceed_payload_max_length = false;
   return prepare_content_receiver(
-      x, status, std::move(receiver), decompress,
-      [&](const ContentReceiverWithProgress &out) {
+      x, status, std::move(receiver), decompress, payload_max_length,
+      exceed_payload_max_length, [&](const ContentReceiverWithProgress &out) {
         auto ret = true;
-        auto exceed_payload_max_length = false;
+        // Note: exceed_payload_max_length may also be set by the decompressor
+        // wrapper in prepare_content_receiver when the decompressed payload
+        // size exceeds the limit.
 
         if (is_chunked_transfer_encoding(x.headers)) {
           auto result = read_content_chunked(strm, x, payload_max_length, out);
@@ -11288,45 +11302,63 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) {
   if (detail::expect_content(req)) {
     // Content reader handler
     {
+      // Track whether the ContentReader was aborted due to the decompressed
+      // payload exceeding `payload_max_length_`.
+      // The user handler runs after the lambda returns, so we must restore the
+      // 413 status if the handler overwrites it.
+      bool content_reader_payload_too_large = false;
+
       ContentReader reader(
           [&](ContentReceiver receiver) {
             auto result = read_content_with_content_receiver(
                 strm, req, res, std::move(receiver), nullptr, nullptr);
-            if (!result) { output_error_log(Error::Read, &req); }
+            if (!result) {
+              output_error_log(Error::Read, &req);
+              if (res.status == StatusCode::PayloadTooLarge_413) {
+                content_reader_payload_too_large = true;
+              }
+            }
             return result;
           },
           [&](FormDataHeader header, ContentReceiver receiver) {
             auto result = read_content_with_content_receiver(
                 strm, req, res, nullptr, std::move(header),
                 std::move(receiver));
-            if (!result) { output_error_log(Error::Read, &req); }
+            if (!result) {
+              output_error_log(Error::Read, &req);
+              if (res.status == StatusCode::PayloadTooLarge_413) {
+                content_reader_payload_too_large = true;
+              }
+            }
             return result;
           });
 
+      bool dispatched = false;
       if (req.method == "POST") {
-        if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                post_handlers_for_content_reader_)) {
-          return true;
-        }
+        dispatched = dispatch_request_for_content_reader(
+            req, res, std::move(reader), post_handlers_for_content_reader_);
       } else if (req.method == "PUT") {
-        if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                put_handlers_for_content_reader_)) {
-          return true;
-        }
+        dispatched = dispatch_request_for_content_reader(
+            req, res, std::move(reader), put_handlers_for_content_reader_);
       } else if (req.method == "PATCH") {
-        if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                patch_handlers_for_content_reader_)) {
-          return true;
-        }
+        dispatched = dispatch_request_for_content_reader(
+            req, res, std::move(reader), patch_handlers_for_content_reader_);
       } else if (req.method == "DELETE") {
-        if (dispatch_request_for_content_reader(
-                req, res, std::move(reader),
-                delete_handlers_for_content_reader_)) {
-          return true;
+        dispatched = dispatch_request_for_content_reader(
+            req, res, std::move(reader), delete_handlers_for_content_reader_);
+      }
+
+      if (dispatched) {
+        if (content_reader_payload_too_large) {
+          // Enforce the limit: override any status the handler may have set
+          // and return false so the error path sends a plain 413 response.
+          res.status = StatusCode::PayloadTooLarge_413;
+          res.body.clear();
+          res.content_length_ = 0;
+          res.content_provider_ = nullptr;
+          return false;
         }
+        return true;
       }
     }
 

+ 63 - 1
test/test.cc

@@ -8649,6 +8649,68 @@ TEST_F(LargePayloadMaxLengthTest, NoContentLengthExceeds10MB) {
   }
 }
 
+#ifdef CPPHTTPLIB_ZLIB_SUPPORT
+// `payload_max_length` is not enforced on decompressed body in ContentReader
+// path.
+TEST(PayloadLimitBypassTest, StreamingGzipDecompression) {
+  Server svr;
+  const size_t LIMIT = 64 * 1024; // 64KB
+  svr.set_payload_max_length(LIMIT);
+
+  size_t total = 0;
+  svr.Post("/stream", [&](const Request & /*req*/, Response &res,
+                          const ContentReader &content_reader) {
+    content_reader([&](const char * /*data*/, size_t len) {
+      total += len;
+      return true;
+    });
+    res.status = 200;
+    res.set_content("stream_ok", "text/plain");
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    thread.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+  svr.wait_until_ready();
+
+  // Prepare 256KB raw data and gzip-compress it
+  std::string raw(256 * 1024, 'A');
+  std::string gz;
+  {
+    z_stream zs{};
+    deflateInit2(&zs, Z_BEST_COMPRESSION, Z_DEFLATED, 15 + 16, 8,
+                 Z_DEFAULT_STRATEGY);
+    zs.next_in = reinterpret_cast<Bytef *>(const_cast<char *>(raw.data()));
+    zs.avail_in = static_cast<uInt>(raw.size());
+    char outbuf[4096];
+    int ret;
+    do {
+      zs.next_out = reinterpret_cast<Bytef *>(outbuf);
+      zs.avail_out = sizeof(outbuf);
+      ret = deflate(&zs, Z_FINISH);
+      gz.append(outbuf, sizeof(outbuf) - zs.avail_out);
+    } while (ret != Z_STREAM_END);
+    deflateEnd(&zs);
+  }
+
+  Client cli(HOST, PORT);
+  cli.set_connection_timeout(std::chrono::seconds(5));
+  Headers headers = {{"Content-Encoding", "gzip"}};
+  auto res = cli.Post("/stream", headers, gz.data(), gz.size(),
+                      "application/octet-stream");
+  ASSERT_TRUE(res);
+
+  // Server must reject oversized decompressed payloads with 413.
+  EXPECT_EQ(StatusCode::PayloadTooLarge_413, res->status);
+
+  // Decompressed bytes delivered to the handler must not exceed LIMIT.
+  EXPECT_LE(total, LIMIT);
+}
+#endif
+
 // Regression test for DoS vulnerability: a malicious server sending a response
 // without Content-Length header must not cause unbounded memory consumption on
 // the client side. The client should stop reading after a reasonable limit,
@@ -15658,7 +15720,7 @@ TEST(ZipBombProtectionTest, DecompressedSizeExceedsLimit) {
 
   // Server should reject because decompressed size (8KB) exceeds limit (1KB)
   ASSERT_TRUE(res);
-  EXPECT_EQ(StatusCode::BadRequest_400, res->status);
+  EXPECT_EQ(StatusCode::PayloadTooLarge_413, res->status);
 }
 #endif