Преглед изворног кода

Optimize multipart content provider to coalesce small writes and reduce TCP packet fragmentation (Fix #2410)

yhirose пре 4 дана
родитељ
комит
a9359df42e
2 измењена фајла са 139 додатих и 10 уклоњених редова
  1. 48 10
      httplib.h
  2. 91 0
      test/test.cc

+ 48 - 10
httplib.h

@@ -1571,6 +1571,13 @@ ssize_t write_headers(Stream &strm, const Headers &headers);
 bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec,
                          time_t usec);
 
+size_t get_multipart_content_length(const UploadFormDataItems &items,
+                                    const std::string &boundary);
+
+ContentProvider
+make_multipart_content_provider(const UploadFormDataItems &items,
+                                const std::string &boundary);
+
 } // namespace detail
 
 class Server {
@@ -8242,6 +8249,7 @@ make_multipart_content_provider(const UploadFormDataItems &items,
   struct MultipartState {
     std::vector<std::string> owned;
     std::vector<MultipartSegment> segs;
+    std::vector<char> buf = std::vector<char>(CPPHTTPLIB_SEND_BUFSIZ);
   };
   auto state = std::make_shared<MultipartState>();
   state->owned = std::move(owned);
@@ -8250,19 +8258,49 @@ make_multipart_content_provider(const UploadFormDataItems &items,
   state->segs = std::move(segs);
 
   return [state](size_t offset, size_t length, DataSink &sink) -> bool {
+    // Buffer multiple small segments into fewer, larger writes to avoid
+    // excessive TCP packets when there are many form data items (#2410)
+    auto &buf = state->buf; // NOTE(review): buf lives in the shared 'state', so copies of this provider share it — confirm no concurrent calls
+    auto buf_size = buf.size();
+    size_t buf_len = 0; // bytes currently staged in buf
+    size_t remaining = length; // bytes this call may still emit
+
+    // Find the first segment containing 'offset' (segments are laid out back to back; 'pos' is the start of the current one)
     size_t pos = 0;
-    for (const auto &seg : state->segs) {
-      // Loop invariant: pos <= offset (proven by advancing pos only when
-      // offset - pos >= seg.size, i.e., the segment doesn't contain offset)
-      if (seg.size > 0 && offset - pos < seg.size) {
-        size_t seg_offset = offset - pos;
-        size_t available = seg.size - seg_offset;
-        size_t to_write = (std::min)(available, length);
-        return sink.write(seg.data + seg_offset, to_write);
-      }
+    size_t seg_idx = 0;
+    for (; seg_idx < state->segs.size(); seg_idx++) {
+      const auto &seg = state->segs[seg_idx];
+      if (seg.size > 0 && offset - pos < seg.size) { break; } // pos <= offset holds here, so the unsigned subtraction cannot wrap
       pos += seg.size;
     }
-    return true; // past end (shouldn't be reached when content_length is exact)
+
+    size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0; // 0 when 'offset' lies past the last segment
+
+    for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) {
+      const auto &seg = state->segs[seg_idx];
+      size_t available = seg.size - seg_offset; // seg_offset < seg.size for the first segment, 0 for later ones
+      size_t to_copy = (std::min)(available, remaining);
+      const char *src = seg.data + seg_offset;
+      seg_offset = 0; // only the first segment has a non-zero offset
+
+      while (to_copy > 0) { // a segment larger than the free space is copied across several buffer fills
+        size_t space = buf_size - buf_len;
+        size_t chunk = (std::min)(to_copy, space);
+        std::memcpy(buf.data() + buf_len, src, chunk);
+        buf_len += chunk;
+        src += chunk;
+        to_copy -= chunk;
+        remaining -= chunk;
+
+        if (buf_len == buf_size) { // buffer full: flush it and start refilling from the front
+          if (!sink.write(buf.data(), buf_len)) { return false; }
+          buf_len = 0;
+        }
+      }
+    }
+
+    if (buf_len > 0) { return sink.write(buf.data(), buf_len); } // flush the partially filled tail
+    return true; // nothing staged, e.g. 'offset' past the last segment or length == 0
   };
 }
 

+ 91 - 0
test/test.cc

@@ -12035,6 +12035,97 @@ TEST(MultipartFormDataTest, UploadItemsHasContentLength) {
   EXPECT_EQ(StatusCode::OK_200, res->status);
 }
 
+TEST(MultipartFormDataTest, ContentProviderCoalescesWrites) {
+  // Verify that make_multipart_content_provider coalesces many small segments
+  // into fewer sink.write() calls to avoid TCP packet fragmentation (#2410).
+  constexpr size_t kItemCount = 1000;
+
+  UploadFormDataItems items;
+  items.reserve(kItemCount);
+  for (size_t i = 0; i < kItemCount; i++) {
+    items.push_back(
+        {"field" + std::to_string(i), "value" + std::to_string(i), "", ""});
+  }
+
+  const std::string boundary = "----test-boundary";
+  auto content_length = detail::get_multipart_content_length(items, boundary);
+  auto provider = detail::make_multipart_content_provider(items, boundary);
+
+  // Drive the provider the same way write_content_with_progress does
+  size_t write_count = 0; // number of sink.write() calls observed
+  size_t total_bytes = 0; // sum of the lengths passed to sink.write()
+
+  DataSink sink;
+  size_t offset = 0;
+  sink.write = [&](const char *d, size_t l) -> bool {
+    (void)d; // payload is not inspected; only call count and sizes matter here
+    write_count++;
+    total_bytes += l;
+    offset += l; // advances the offset that the loop below feeds back into the provider
+    return true;
+  };
+  sink.is_writable = []() -> bool { return true; };
+
+  while (offset < content_length) { // NOTE(review): would spin forever if the provider returned true without writing
+    ASSERT_TRUE(provider(offset, content_length - offset, sink));
+  }
+
+  EXPECT_EQ(content_length, total_bytes);
+
+  // The total number of segments is 3 * kItemCount + 1 = 3001.
+  // With buffering into CPPHTTPLIB_SEND_BUFSIZ-sized blocks, write_count should be much smaller.
+  auto segment_count = 3 * kItemCount + 1;
+  EXPECT_LT(write_count, segment_count / 10); // i.e. fewer than 300 writes
+}
+
+TEST(MultipartFormDataTest, ManyItemsEndToEnd) {
+  // Integration test: post many UploadFormDataItems through a real Client and
+  // Server and verify every field arrives with its expected value (#2410).
+  constexpr size_t kItemCount = 500;
+
+  auto handled = false; // set by the handler; asserted in the scope_exit block below
+
+  Server svr;
+  svr.Post("/upload", [&](const Request &req, Response &res) {
+    EXPECT_EQ(kItemCount, req.form.fields.size());
+    for (size_t i = 0; i < kItemCount; i++) {
+      auto key = "field" + std::to_string(i);
+      auto val = "value" + std::to_string(i);
+      auto it = req.form.fields.find(key);
+      if (it != req.form.fields.end()) {
+        EXPECT_EQ(val, it->second.content);
+      } else { // report the missing key instead of dereferencing end()
+        ADD_FAILURE() << "Missing field: " << key;
+      }
+    }
+    res.set_content("ok", "text/plain");
+    handled = true;
+  });
+
+  auto port = svr.bind_to_any_port(HOST);
+  auto t = thread([&] { svr.listen_after_bind(); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    t.join();
+    ASSERT_FALSE(svr.is_running());
+    ASSERT_TRUE(handled); // fails the test if the handler was never invoked
+  });
+
+  svr.wait_until_ready();
+
+  UploadFormDataItems items;
+  items.reserve(kItemCount);
+  for (size_t i = 0; i < kItemCount; i++) {
+    items.push_back(
+        {"field" + std::to_string(i), "value" + std::to_string(i), "", ""}); // same name/value scheme the handler checks
+  }
+
+  Client cli(HOST, port);
+  auto res = cli.Post("/upload", items);
+  ASSERT_TRUE(res);
+  EXPECT_EQ(StatusCode::OK_200, res->status);
+}
+
 TEST(MultipartFormDataTest, MakeFileProvider) {
   // Verify make_file_provider sends a file's contents correctly.
   const std::string file_content(4096, 'Z');