Browse Source

WebSocket and Dynamic Thread Pool support (#2368)

* WebSocket support

* Validate selected subprotocol in WebSocket handshake

* Fix problem with a Unit test

* Dynamic Thread Pool support

* Fix race condition in new Dynamic ThreadPool
yhirose 1 month ago
parent
commit
464867a9ce
11 changed files with 2434 additions and 55 deletions
  1. 10 0
      .github/workflows/test.yaml
  2. 4 0
      .gitignore
  3. 391 0
      README-websocket.md
  4. 49 4
      README.md
  5. 5 2
      example/Makefile
  6. 135 0
      example/wsecho.cc
  7. 764 40
      httplib.h
  8. 8 0
      test/Makefile
  9. 763 9
      test/test.cc
  10. 228 0
      test/test_thread_pool.cc
  11. 77 0
      test/test_websocket_heartbeat.cc

+ 10 - 0
.github/workflows/test.yaml

@@ -103,6 +103,11 @@ jobs:
       - name: run fuzz test target
         if: matrix.tls_backend == 'openssl'
         run: cd test && make fuzz_test
+      - name: build and run WebSocket heartbeat test
+        if: matrix.tls_backend == 'openssl'
+        run: cd test && make test_websocket_heartbeat && ./test_websocket_heartbeat
+      - name: build and run ThreadPool test
+        run: cd test && make test_thread_pool && ./test_thread_pool
 
   macos:
     runs-on: macos-latest
@@ -132,6 +137,11 @@ jobs:
       - name: run fuzz test target
         if: matrix.tls_backend == 'openssl'
         run: cd test && make fuzz_test
+      - name: build and run WebSocket heartbeat test
+        if: matrix.tls_backend == 'openssl'
+        run: cd test && make test_websocket_heartbeat && ./test_websocket_heartbeat
+      - name: build and run ThreadPool test
+        run: cd test && make test_thread_pool && ./test_thread_pool
 
   windows:
     runs-on: windows-latest

+ 4 - 0
.gitignore

@@ -29,6 +29,8 @@ example/server_and_client
 !example/server_and_client.*
 example/accept_header
 !example/accept_header.*
+example/wsecho
+!example/wsecho.*
 example/*.pem
 test/httplib.cc
 test/httplib.h
@@ -41,6 +43,8 @@ test/test_proxy_mbedtls
 test/test_split
 test/test_split_mbedtls
 test/test_split_no_tls
+test/test_websocket_heartbeat
+test/test_thread_pool
 test/test.xcodeproj/xcuser*
 test/test.xcodeproj/*/xcuser*
 test/*.o

+ 391 - 0
README-websocket.md

@@ -0,0 +1,391 @@
+# WebSocket - RFC 6455 WebSocket Support
+
+A simple, blocking WebSocket implementation for C++11.
+
+> [!IMPORTANT]
+> This is a blocking I/O WebSocket implementation using a thread-per-connection model. If you need high-concurrency WebSocket support with non-blocking/async I/O (e.g., thousands of simultaneous connections), this is not the one that you want.
+
+## Features
+
+- **RFC 6455 compliant**: Full WebSocket protocol support
+- **Server and Client**: Both sides included
+- **SSL/TLS support**: `wss://` scheme for secure connections
+- **Text and Binary**: Both message types supported
+- **Automatic heartbeat**: Periodic Ping/Pong keeps connections alive
+- **Subprotocol negotiation**: `Sec-WebSocket-Protocol` support for GraphQL, MQTT, etc.
+
+## Quick Start
+
+### Server
+
+```cpp
+httplib::Server svr;
+
+svr.WebSocket("/ws", [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+    std::string msg;
+    while (ws.read(msg)) {
+        ws.send("echo: " + msg);
+    }
+});
+
+svr.listen("localhost", 8080);
+```
+
+### Client
+
+```cpp
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws");
+
+if (ws.connect()) {
+    ws.send("hello");
+
+    std::string msg;
+    if (ws.read(msg)) {
+        std::cout << msg << std::endl;  // "echo: hello"
+    }
+    ws.close();
+}
+```
+
+## API Reference
+
+### ReadResult
+
+```cpp
+enum ReadResult : int {
+    Fail   = 0,  // Connection closed or error
+    Text   = 1,  // UTF-8 text message
+    Binary = 2,  // Binary message
+};
+```
+
+Returned by `read()`. Since `Fail` is `0`, the result works naturally in boolean contexts — `while (ws.read(msg))` continues until the connection closes. When you need to distinguish text from binary, check the return value directly.
+
+### CloseStatus
+
+```cpp
+enum class CloseStatus : uint16_t {
+    Normal = 1000,
+    GoingAway = 1001,
+    ProtocolError = 1002,
+    UnsupportedData = 1003,
+    NoStatus = 1005,
+    Abnormal = 1006,
+    InvalidPayload = 1007,
+    PolicyViolation = 1008,
+    MessageTooBig = 1009,
+    MandatoryExtension = 1010,
+    InternalError = 1011,
+};
+```
+
+### Server Registration
+
+```cpp
+// Basic handler
+Server &WebSocket(const std::string &pattern, WebSocketHandler handler);
+
+// With subprotocol negotiation
+Server &WebSocket(const std::string &pattern, WebSocketHandler handler,
+                  SubProtocolSelector sub_protocol_selector);
+```
+
+**Type aliases:**
+
+```cpp
+using WebSocketHandler =
+    std::function<void(const Request &, ws::WebSocket &)>;
+using SubProtocolSelector =
+    std::function<std::string(const std::vector<std::string> &protocols)>;
+```
+
+The `SubProtocolSelector` receives the list of subprotocols proposed by the client (from the `Sec-WebSocket-Protocol` header) and returns the selected one. Return an empty string to decline all proposed subprotocols.
+
+### WebSocket (Server-side)
+
+Passed to the handler registered with `Server::WebSocket()`. The handler runs in a dedicated thread per connection.
+
+```cpp
+// Read next message (blocks until received, returns Fail/Text/Binary)
+ReadResult read(std::string &msg);
+
+// Send messages
+bool send(const std::string &data);              // Text
+bool send(const char *data, size_t len);          // Binary
+
+// Close the connection
+void close(CloseStatus status = CloseStatus::Normal,
+           const std::string &reason = "");
+
+// Access the original HTTP upgrade request
+const Request &request() const;
+
+// Check if the connection is still open
+bool is_open() const;
+```
+
+### WebSocketClient
+
+```cpp
+// Constructor - accepts ws:// or wss:// URL
+explicit WebSocketClient(const std::string &scheme_host_port_path,
+                         const Headers &headers = {});
+
+// Check if the URL was parsed successfully
+bool is_valid() const;
+
+// Connect (performs HTTP upgrade handshake)
+bool connect();
+
+// Get the subprotocol selected by the server (empty if none)
+const std::string &subprotocol() const;
+
+// Read/Send/Close (same as server-side WebSocket)
+ReadResult read(std::string &msg);
+bool send(const std::string &data);
+bool send(const char *data, size_t len);
+void close(CloseStatus status = CloseStatus::Normal,
+           const std::string &reason = "");
+bool is_open() const;
+
+// Timeouts
+void set_read_timeout(time_t sec, time_t usec = 0);
+void set_write_timeout(time_t sec, time_t usec = 0);
+
+// SSL configuration (wss:// only, requires CPPHTTPLIB_OPENSSL_SUPPORT)
+void set_ca_cert_path(const std::string &path);
+void set_ca_cert_store(tls::ca_store_t store);
+void enable_server_certificate_verification(bool enabled);
+```
+
+## Examples
+
+### Echo Server with Connection Logging
+
+```cpp
+httplib::Server svr;
+
+svr.WebSocket("/ws", [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+    std::cout << "Connected from " << req.remote_addr << std::endl;
+
+    std::string msg;
+    while (ws.read(msg)) {
+        ws.send("echo: " + msg);
+    }
+
+    std::cout << "Disconnected" << std::endl;
+});
+
+svr.listen("localhost", 8080);
+```
+
+### Client: Continuous Read Loop
+
+```cpp
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws");
+
+if (ws.connect()) {
+    ws.send("hello");
+    ws.send("world");
+
+    std::string msg;
+    while (ws.read(msg)) {           // blocks until a message arrives
+        std::cout << msg << std::endl; // "echo: hello", "echo: world"
+    }
+    // read() returns false when the server closes the connection
+}
+```
+
+### Text and Binary Messages
+
+Check the `ReadResult` return value to distinguish between text and binary:
+
+```cpp
+// Server
+svr.WebSocket("/ws", [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+    std::string msg;
+    httplib::ws::ReadResult ret;
+    while ((ret = ws.read(msg))) {
+        if (ret == httplib::ws::Text) {
+            ws.send("echo: " + msg);
+        } else {
+            ws.send(msg.data(), msg.size());  // Binary echo
+        }
+    }
+});
+
+// Client
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws");
+if (ws.connect()) {
+    // Send binary data
+    const char binary[] = {0x00, 0x01, 0x02, 0x03};
+    ws.send(binary, sizeof(binary));
+
+    // Receive and check the type
+    std::string msg;
+    if (ws.read(msg) == httplib::ws::Binary) {
+        // Process binary data in msg
+    }
+    ws.close();
+}
+```
+
+### SSL Client
+
+```cpp
+httplib::ws::WebSocketClient ws("wss://echo.example.com/ws");
+
+if (ws.connect()) {
+    ws.send("hello over TLS");
+
+    std::string msg;
+    if (ws.read(msg)) {
+        std::cout << msg << std::endl;
+    }
+    ws.close();
+}
+```
+
+### Close with Status
+
+```cpp
+// Client-side: close with a specific status code and reason
+ws.close(httplib::ws::CloseStatus::GoingAway, "shutting down");
+
+// Server-side: close with a policy violation status
+ws.close(httplib::ws::CloseStatus::PolicyViolation, "forbidden");
+```
+
+### Accessing the Upgrade Request
+
+```cpp
+svr.WebSocket("/ws", [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+    // Access headers from the original HTTP upgrade request
+    auto auth = req.get_header_value("Authorization");
+    if (auth.empty()) {
+        ws.close(httplib::ws::CloseStatus::PolicyViolation, "unauthorized");
+        return;
+    }
+
+    std::string msg;
+    while (ws.read(msg)) {
+        ws.send("echo: " + msg);
+    }
+});
+```
+
+### Custom Headers and Timeouts
+
+```cpp
+httplib::Headers headers = {
+    {"Authorization", "Bearer token123"}
+};
+
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws", headers);
+ws.set_read_timeout(30, 0);   // 30 seconds
+ws.set_write_timeout(10, 0);  // 10 seconds
+
+if (ws.connect()) {
+    std::string msg;
+    while (ws.read(msg)) {
+        std::cout << msg << std::endl;
+    }
+}
+```
+
+### Subprotocol Negotiation
+
+The server can negotiate a subprotocol with the client using `Sec-WebSocket-Protocol`. This is required for protocols like GraphQL over WebSocket (`graphql-ws`) and MQTT.
+
+```cpp
+// Server: register a handler with a subprotocol selector
+svr.WebSocket(
+    "/ws",
+    [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+        std::string msg;
+        while (ws.read(msg)) {
+            ws.send("echo: " + msg);
+        }
+    },
+    [](const std::vector<std::string> &protocols) -> std::string {
+        // The client proposed a list of subprotocols; pick one
+        for (const auto &p : protocols) {
+            if (p == "graphql-ws" || p == "graphql-transport-ws") {
+                return p;
+            }
+        }
+        return "";  // Decline all
+    });
+
+// Client: propose subprotocols via Sec-WebSocket-Protocol header
+httplib::Headers headers = {
+    {"Sec-WebSocket-Protocol", "graphql-ws, graphql-transport-ws"}
+};
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws", headers);
+
+if (ws.connect()) {
+    // Check which subprotocol the server selected
+    std::cout << "Subprotocol: " << ws.subprotocol() << std::endl;
+    // => "graphql-ws"
+    ws.close();
+}
+```
+
+### SSL Client with Certificate Configuration
+
+```cpp
+httplib::ws::WebSocketClient ws("wss://example.com/ws");
+ws.set_ca_cert_path("/path/to/ca-bundle.crt");
+ws.enable_server_certificate_verification(true);
+
+if (ws.connect()) {
+    ws.send("secure message");
+    ws.close();
+}
+```
+
+## Configuration
+
+| Macro                                       | Default           | Description                                              |
+|---------------------------------------------|-------------------|----------------------------------------------------------|
+| `CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH`   | `16777216` (16MB) | Maximum payload size per message                         |
+| `CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND`  | `300`             | Read timeout for WebSocket connections (seconds)         |
+| `CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND` | `5`               | Timeout for waiting peer's Close response (seconds)      |
+| `CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND` | `30`              | Automatic Ping interval for heartbeat (seconds)          |
+
+## Threading Model
+
+WebSocket connections share the same thread pool as HTTP requests. Each WebSocket connection occupies one thread for its entire lifetime. The default thread pool size is `CPPHTTPLIB_THREAD_POOL_COUNT` (8).
+
+This means that if you have 8 simultaneous WebSocket connections with the default settings, **no more HTTP requests can be processed** until a WebSocket connection closes and frees up a thread.
+
+If your application uses WebSocket, you should increase the thread pool size:
+
+```cpp
+httplib::Server svr;
+
+svr.new_task_queue = [] {
+  return new httplib::ThreadPool(128); // Increase from default 8
+};
+```
+
+Choose a size that accounts for both your expected HTTP load and the maximum number of simultaneous WebSocket connections.
+
+## Protocol
+
+The implementation follows [RFC 6455](https://tools.ietf.org/html/rfc6455):
+
+- Handshake via HTTP Upgrade with `Sec-WebSocket-Key` / `Sec-WebSocket-Accept`
+- Subprotocol negotiation via `Sec-WebSocket-Protocol`
+- Frame masking (client-to-server)
+- Control frames: Close, Ping, Pong
+- Message fragmentation and reassembly
+- Close handshake with status codes
+
+## Browser Test
+
+Run the echo server example and open `http://localhost:8080` in a browser:
+
+```bash
+cd example && make wsecho && ./wsecho
+```

+ 49 - 4
README.md

@@ -711,20 +711,24 @@ Please see [Server example](https://github.com/yhirose/cpp-httplib/blob/master/e
 
 ### Default thread pool support
 
-`ThreadPool` is used as the **default** task queue, with a default thread count of 8 or `std::thread::hardware_concurrency() - 1`, whichever is greater. You can change it with `CPPHTTPLIB_THREAD_POOL_COUNT`.
+`ThreadPool` is used as the **default** task queue, with dynamic scaling support. By default, it maintains a base thread count of 8 or `std::thread::hardware_concurrency() - 1` (whichever is greater), and can scale up to 4x that count under load. You can change these with `CPPHTTPLIB_THREAD_POOL_COUNT` and `CPPHTTPLIB_THREAD_POOL_MAX_COUNT`.
 
-If you want to set the thread count at runtime, there is no convenient way... But here is how.
+When all threads are busy and a new task arrives, a temporary thread is spawned (up to the maximum). When a dynamic thread finishes its task and the queue is empty, or after an idle timeout, it exits automatically. The idle timeout defaults to 3 seconds, configurable via `CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT`.
+
+If you want to set the thread counts at runtime:
 
 ```cpp
-svr.new_task_queue = [] { return new ThreadPool(12); };
+svr.new_task_queue = [] { return new ThreadPool(/*base_threads=*/8, /*max_threads=*/64); };
 ```
 
+#### Max queued requests
+
 You can also provide an optional parameter to limit the maximum number
 of pending requests, i.e. requests `accept()`ed by the listener but
 still waiting to be serviced by worker threads.
 
 ```cpp
-svr.new_task_queue = [] { return new ThreadPool(/*num_threads=*/12, /*max_queued_requests=*/18); };
+svr.new_task_queue = [] { return new ThreadPool(/*base_threads=*/12, /*max_threads=*/0, /*max_queued_requests=*/18); };
 ```
 
 Default limit is 0 (unlimited). Once the limit is reached, the listener
@@ -1344,6 +1348,47 @@ int main() {
 
 See [README-sse.md](README-sse.md) for more details.
 
+WebSocket
+---------
+
+```cpp
+// Server
+httplib::Server svr;
+
+svr.WebSocket("/ws", [](const httplib::Request &req, httplib::ws::WebSocket &ws) {
+    httplib::ws::Message msg;
+    while (ws.read(msg)) {
+        if (msg.is_text()) {
+            ws.send("Echo: " + msg.data);
+        }
+    }
+});
+
+svr.listen("localhost", 8080);
+```
+
+```cpp
+// Client
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws");
+
+if (ws.connect()) {
+    ws.send("Hello, WebSocket!");
+
+    std::string msg;
+    if (ws.read(msg)) {
+        std::cout << "Received: " << msg << std::endl;
+    }
+
+    ws.close();
+}
+```
+
+SSL is also supported via `wss://` scheme (e.g. `WebSocketClient("wss://example.com/ws")`). Subprotocol negotiation (`Sec-WebSocket-Protocol`) is supported via `SubProtocolSelector` callback.
+
+> **Note:** WebSocket connections occupy a thread for their entire lifetime. If you plan to handle many simultaneous WebSocket connections, consider using a dynamic thread pool: `svr.new_task_queue = [] { return new ThreadPool(8, 64); };`
+
+See [README-websocket.md](README-websocket.md) for more details.
+
 Split httplib.h into .h and .cc
 -------------------------------
 

+ 5 - 2
example/Makefile

@@ -18,7 +18,7 @@ ZLIB_SUPPORT = -DCPPHTTPLIB_ZLIB_SUPPORT -lz
 BROTLI_DIR = $(PREFIX)/opt/brotli
 BROTLI_SUPPORT = -DCPPHTTPLIB_BROTLI_SUPPORT -I$(BROTLI_DIR)/include -L$(BROTLI_DIR)/lib -lbrotlicommon -lbrotlienc -lbrotlidec
 
-all: server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client accept_header
+all: server client hello simplecli simplesvr upload redirect ssesvr ssecli wsecho benchmark one_time_request server_and_client accept_header
 
 server : server.cc ../httplib.h Makefile
 	$(CXX) -o server $(CXXFLAGS) server.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT)
@@ -47,6 +47,9 @@ ssesvr : ssesvr.cc ../httplib.h Makefile
 ssecli : ssecli.cc ../httplib.h Makefile
 	$(CXX) -o ssecli $(CXXFLAGS) ssecli.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT)
 
+wsecho : wsecho.cc ../httplib.h Makefile
+	$(CXX) -o wsecho $(CXXFLAGS) wsecho.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT)
+
 benchmark : benchmark.cc ../httplib.h Makefile
 	$(CXX) -o benchmark $(CXXFLAGS) benchmark.cc $(OPENSSL_SUPPORT) $(ZLIB_SUPPORT) $(BROTLI_SUPPORT)
 
@@ -64,4 +67,4 @@ pem:
 	openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem
 
 clean:
-	rm server client hello simplecli simplesvr upload redirect ssesvr ssecli benchmark one_time_request server_and_client accept_header *.pem
+	rm server client hello simplecli simplesvr upload redirect ssesvr ssecli wsecho benchmark one_time_request server_and_client accept_header *.pem

+ 135 - 0
example/wsecho.cc

@@ -0,0 +1,135 @@
+#include <httplib.h>
+#include <iostream>
+
+using namespace httplib;
+
+const auto html = R"HTML(
+<!DOCTYPE html>
+<html lang="en">
+<head>
+<meta charset="UTF-8">
+<title>WebSocket Demo</title>
+<style>
+  body { font-family: monospace; margin: 2em; }
+  #log { height: 300px; overflow-y: scroll; border: 1px solid #ccc; padding: 8px; }
+  .controls { margin: 8px 0; }
+  button { margin-right: 4px; }
+</style>
+</head>
+<body>
+<h1>WebSocket Demo</h1>
+<p>Server accepts subprotocols: <b>echo</b>, <b>chat</b> (or none)</p>
+
+<div class="controls">
+  <label>Subprotocols: </label>
+  <input id="protos" type="text" value="echo, chat" placeholder="leave empty for none" />
+  <button onclick="doConnect()">Connect</button>
+  <button onclick="doDisconnect()">Disconnect</button>
+</div>
+
+<div class="controls">
+  <input id="msg" type="text" placeholder="Type a message..." />
+  <button onclick="doSend()">Send</button>
+</div>
+
+<div class="controls">
+  <button onclick="startAuto()">Start Auto (1s)</button>
+  <button onclick="stopAuto()">Stop Auto</button>
+  <span id="auto-status"></span>
+</div>
+
+<pre id="log"></pre>
+
+<script>
+var sock = null;
+var logEl = document.getElementById("log");
+var statusEl = document.getElementById("auto-status");
+var timer = null;
+var seq = 0;
+
+function appendLog(text) {
+  logEl.textContent += text + "\n";
+  logEl.scrollTop = logEl.scrollHeight;
+}
+
+function doConnect() {
+  if (sock && sock.readyState <= 1) { sock.close(); }
+  var input = document.getElementById("protos").value.trim();
+  var protocols = input ? input.split(/\s*,\s*/).filter(Boolean) : [];
+  sock = new WebSocket("ws://" + location.host + "/ws", protocols);
+  appendLog("[connecting] proposed: " + (protocols.length ? protocols.join(", ") : "(none)"));
+  sock.onopen = function() { appendLog("[connected] subprotocol: " + (sock.protocol || "(none)")); };
+  sock.onclose = function() { appendLog("[disconnected]"); stopAuto(); };
+  sock.onmessage = function(e) { appendLog("< " + e.data); };
+}
+
+function doDisconnect() {
+  if (sock) { sock.close(); }
+}
+
+function doSend() {
+  var input = document.getElementById("msg");
+  if (!sock || sock.readyState !== 1 || input.value === "") return;
+  sock.send(input.value);
+  appendLog("> " + input.value);
+  input.value = "";
+}
+
+function startAuto() {
+  if (timer || !sock || sock.readyState !== 1) return;
+  seq = 0;
+  statusEl.textContent = "running...";
+  timer = setInterval(function() {
+    if (!sock || sock.readyState !== 1) { stopAuto(); return; }
+    var msg = "auto #" + seq++;
+    sock.send(msg);
+    appendLog("> " + msg);
+  }, 1000);
+}
+
+function stopAuto() {
+  if (timer) { clearInterval(timer); timer = null; }
+  statusEl.textContent = "";
+}
+
+document.getElementById("msg").addEventListener("keydown", function(e) {
+  if (e.key === "Enter") doSend();
+});
+
+doConnect();
+</script>
+</body>
+</html>
+)HTML";
+
+int main(void) {
+  Server svr;
+
+  svr.Get("/", [&](const Request & /*req*/, Response &res) {
+    res.set_content(html, "text/html");
+  });
+
+  svr.WebSocket(
+      "/ws",
+      [](const Request &req, ws::WebSocket &ws) {
+        std::cout << "WebSocket connected from " << req.remote_addr
+                  << std::endl;
+
+        std::string msg;
+        while (ws.read(msg)) {
+          std::cout << "Received: " << msg << std::endl;
+          ws.send("echo: " + msg);
+        }
+
+        std::cout << "WebSocket disconnected" << std::endl;
+      },
+      [](const std::vector<std::string> &protocols) -> std::string {
+        for (const auto &p : protocols) {
+          if (p == "echo" || p == "chat") { return p; }
+        }
+        return "";
+      });
+
+  std::cout << "Listening on http://localhost:8080" << std::endl;
+  svr.listen("localhost", 8080);
+}

File diff suppressed because it is too large
+ 764 - 40
httplib.h


+ 8 - 0
test/Makefile

@@ -155,6 +155,10 @@ test_no_tls : test.cc include_httplib.cc ../httplib.h Makefile
 test_split_no_tls : test.cc ../httplib.h httplib.cc Makefile
 	$(CXX) -o $@ $(CXXFLAGS) test.cc httplib.cc $(TEST_ARGS_NO_TLS)
 
+# ThreadPool unit tests (no TLS, no compression needed)
+test_thread_pool : test_thread_pool.cc ../httplib.h Makefile
+	$(CXX) -o $@ -I.. $(CXXFLAGS) test_thread_pool.cc gtest/src/gtest-all.cc gtest/src/gtest_main.cc -Igtest -Igtest/include -lpthread
+
 check_abi:
 	@./check-shared-library-abi-compatibility.sh
 
@@ -180,6 +184,10 @@ style_check: $(STYLE_CHECK_FILES)
 		echo "All files are properly formatted."; \
 	fi
 
+test_websocket_heartbeat : test_websocket_heartbeat.cc ../httplib.h Makefile
+	$(CXX) -o $@ -I.. $(CXXFLAGS) test_websocket_heartbeat.cc $(TEST_ARGS)
+	@file $@
+
 test_proxy : test_proxy.cc ../httplib.h Makefile cert.pem
 	$(CXX) -o $@ -I.. $(CXXFLAGS) test_proxy.cc $(TEST_ARGS)
 

+ 763 - 9
test/test.cc

@@ -4627,15 +4627,9 @@ TEST_F(ServerTest, HeaderCountExceedsLimit) {
   cli_.set_keep_alive(true);
   auto res = cli_.Get("/hi", headers);
 
-  // The request should either fail or return 400 Bad Request
-  if (res) {
-    // If we get a response, it should be 400 Bad Request
-    EXPECT_EQ(StatusCode::BadRequest_400, res->status);
-  } else {
-    // Or the request should fail entirely
-    EXPECT_FALSE(res);
-  }
-
+  // The server should respond with 400 Bad Request
+  ASSERT_TRUE(res);
+  EXPECT_EQ(StatusCode::BadRequest_400, res->status);
   EXPECT_EQ("close", res->get_header_value("Connection"));
   EXPECT_FALSE(cli_.is_socket_open());
 }
@@ -15704,3 +15698,763 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtxMbedTLS) {
   ASSERT_EQ(StatusCode::OK_200, res->status);
 }
 #endif
+
+// WebSocket Tests
+
+TEST(WebSocketTest, RSVBitsMustBeZero) {
+  // RFC 6455 Section 5.2: RSV1, RSV2, RSV3 MUST be 0 unless an extension
+  // defining the meaning of these bits has been negotiated.
+  auto make_frame = [](uint8_t first_byte) {
+    std::string frame;
+    frame += static_cast<char>(first_byte); // FIN + RSV + opcode
+    frame += static_cast<char>(0x05);       // mask=0, payload_len=5
+    frame += "Hello";
+    return frame;
+  };
+
+  // RSV1 set (0x40)
+  {
+    detail::BufferStream strm;
+    strm.write(make_frame(0x81 | 0x40).data(), 8); // FIN + RSV1 + Text
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // RSV2 set (0x20)
+  {
+    detail::BufferStream strm;
+    strm.write(make_frame(0x81 | 0x20).data(), 8); // FIN + RSV2 + Text
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // RSV3 set (0x10)
+  {
+    detail::BufferStream strm;
+    strm.write(make_frame(0x81 | 0x10).data(), 8); // FIN + RSV3 + Text
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // No RSV bits set - should succeed
+  {
+    detail::BufferStream strm;
+    strm.write(make_frame(0x81).data(), 8); // FIN + Text, no RSV
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_TRUE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                               false, 1024));
+    EXPECT_EQ(ws::Opcode::Text, opcode);
+    EXPECT_EQ("Hello", payload);
+    EXPECT_TRUE(fin);
+  }
+}
+
+TEST(WebSocketTest, ControlFrameValidation) {
+  // RFC 6455 Section 5.5: control frames MUST have FIN=1 and
+  // payload length <= 125.
+
+  // Ping with FIN=0 - must be rejected
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x09); // FIN=0, opcode=Ping
+    frame += static_cast<char>(0x00); // mask=0, payload_len=0
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // Close with FIN=0 - must be rejected
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x08); // FIN=0, opcode=Close
+    frame += static_cast<char>(0x00); // mask=0, payload_len=0
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // Ping with payload_len=126 (extended length) - must be rejected
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x89); // FIN=1, opcode=Ping
+    frame += static_cast<char>(126);  // payload_len=126 (>125)
+    frame += static_cast<char>(0x00); // extended length high byte
+    frame += static_cast<char>(126);  // extended length low byte
+    frame += std::string(126, 'x');
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // Ping with FIN=1 and payload_len=125 - should succeed
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x89); // FIN=1, opcode=Ping
+    frame += static_cast<char>(125);  // payload_len=125
+    frame += std::string(125, 'x');
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_TRUE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                               false, 1024));
+    EXPECT_EQ(ws::Opcode::Ping, opcode);
+    EXPECT_EQ(125u, payload.size());
+    EXPECT_TRUE(fin);
+  }
+}
+
+TEST(WebSocketTest, PayloadLength64BitMSBMustBeZero) {
+  // RFC 6455 Section 5.2: the most significant bit of a 64-bit payload
+  // length MUST be 0.
+
+  // MSB set - must be rejected
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x81); // FIN=1, opcode=Text
+    frame += static_cast<char>(127);  // 64-bit extended length
+    frame += static_cast<char>(0x80); // MSB set (invalid)
+    frame += std::string(7, '\0');    // remaining 7 bytes of length
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_FALSE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                                false, 1024));
+  }
+
+  // MSB clear - should pass length parsing (will be rejected by max_len,
+  // but that's a different check; use a small length to verify)
+  {
+    detail::BufferStream strm;
+    std::string frame;
+    frame += static_cast<char>(0x81); // FIN=1, opcode=Text
+    frame += static_cast<char>(127);  // 64-bit extended length
+    frame += std::string(7, '\0');    // high bytes = 0
+    frame += static_cast<char>(0x03); // length = 3
+    frame += "abc";
+    strm.write(frame.data(), frame.size());
+    ws::Opcode opcode;
+    std::string payload;
+    bool fin;
+    EXPECT_TRUE(ws::impl::read_websocket_frame(strm, opcode, payload, fin,
+                                               false, 1024));
+    EXPECT_EQ(ws::Opcode::Text, opcode);
+    EXPECT_EQ("abc", payload);
+  }
+}
+
+TEST(WebSocketTest, InvalidUTF8TextFrame) {
+  // RFC 6455 Section 5.6: text frames must contain valid UTF-8.
+
+  // Valid UTF-8
+  EXPECT_TRUE(ws::impl::is_valid_utf8("Hello"));
+  EXPECT_TRUE(ws::impl::is_valid_utf8("\xC3\xA9"));         // é (U+00E9)
+  EXPECT_TRUE(ws::impl::is_valid_utf8("\xE3\x81\x82"));     // あ (U+3042)
+  EXPECT_TRUE(ws::impl::is_valid_utf8("\xF0\x9F\x98\x80")); // 😀 (U+1F600)
+  EXPECT_TRUE(ws::impl::is_valid_utf8(""));
+
+  // Invalid UTF-8
+  EXPECT_FALSE(ws::impl::is_valid_utf8("\x80"));     // Invalid start byte
+  EXPECT_FALSE(ws::impl::is_valid_utf8("\xC3\x28")); // Bad continuation
+  EXPECT_FALSE(ws::impl::is_valid_utf8("\xC0\xAF")); // Overlong encoding
+  EXPECT_FALSE(
+      ws::impl::is_valid_utf8("\xED\xA0\x80")); // Surrogate half U+D800
+  EXPECT_FALSE(ws::impl::is_valid_utf8("\xF4\x90\x80\x80")); // Beyond U+10FFFF
+}
+
+TEST(WebSocketTest, ConnectAndDisconnect) {
+  Server svr;
+  svr.WebSocket("/ws", [](const Request &, ws::WebSocket &ws) {
+    std::string msg;
+    while (ws.read(msg)) {}
+  });
+
+  auto port = svr.bind_to_any_port(HOST);
+  std::thread t([&]() { svr.listen_after_bind(); });
+  svr.wait_until_ready();
+
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port) + "/ws");
+  ASSERT_TRUE(client.connect());
+  EXPECT_TRUE(client.is_open());
+  client.close();
+  EXPECT_FALSE(client.is_open());
+
+  svr.stop();
+  t.join();
+}
+
+TEST(WebSocketTest, ValidURL) {
+  ws::WebSocketClient ws1("ws://localhost:8080/path");
+  EXPECT_TRUE(ws1.is_valid());
+
+  ws::WebSocketClient ws2("ws://example.com/path");
+  EXPECT_TRUE(ws2.is_valid());
+
+  ws::WebSocketClient ws3("ws://example.com:9090/path/to/endpoint");
+  EXPECT_TRUE(ws3.is_valid());
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+  ws::WebSocketClient wss1("wss://example.com/path");
+  EXPECT_TRUE(wss1.is_valid());
+
+  ws::WebSocketClient wss2("wss://example.com:443/path");
+  EXPECT_TRUE(wss2.is_valid());
+#endif
+}
+
+TEST(WebSocketTest, InvalidURL) {
+  // No scheme
+  ws::WebSocketClient ws1("localhost:8080/path");
+  EXPECT_FALSE(ws1.is_valid());
+
+  // No path
+  ws::WebSocketClient ws2("ws://localhost:8080");
+  EXPECT_FALSE(ws2.is_valid());
+
+  // Empty string
+  ws::WebSocketClient ws3("");
+  EXPECT_FALSE(ws3.is_valid());
+
+  // Missing host
+  ws::WebSocketClient ws4("ws://:8080/path");
+  EXPECT_FALSE(ws4.is_valid());
+}
+
+TEST(WebSocketTest, UnsupportedScheme) {
+#ifdef CPPHTTPLIB_NO_EXCEPTIONS
+  ws::WebSocketClient ws1("http://localhost:8080/path");
+  EXPECT_FALSE(ws1.is_valid());
+
+  ws::WebSocketClient ws2("https://localhost:8080/path");
+  EXPECT_FALSE(ws2.is_valid());
+
+  ws::WebSocketClient ws3("ftp://localhost:8080/path");
+  EXPECT_FALSE(ws3.is_valid());
+#else
+  EXPECT_THROW(ws::WebSocketClient("http://localhost:8080/path"),
+               std::invalid_argument);
+
+  EXPECT_THROW(ws::WebSocketClient("ftp://localhost:8080/path"),
+               std::invalid_argument);
+#endif
+}
+
+TEST(WebSocketTest, ConnectWhenInvalid) {
+  ws::WebSocketClient ws("not a valid url");
+  EXPECT_FALSE(ws.is_valid());
+  EXPECT_FALSE(ws.connect());
+}
+
+TEST(WebSocketTest, DefaultPort) {
+  ws::WebSocketClient ws1("ws://example.com/path");
+  EXPECT_TRUE(ws1.is_valid());
+  // ws:// defaults to port 80 (verified by successful parse)
+
+#ifdef CPPHTTPLIB_SSL_ENABLED
+  ws::WebSocketClient ws2("wss://example.com/path");
+  EXPECT_TRUE(ws2.is_valid());
+  // wss:// defaults to port 443 (verified by successful parse)
+#endif
+}
+
+TEST(WebSocketTest, IPv6LiteralAddress) {
+  ws::WebSocketClient ws1("ws://[::1]:8080/path");
+  EXPECT_TRUE(ws1.is_valid());
+
+  ws::WebSocketClient ws2("ws://[fe80::1]:3000/ws");
+  EXPECT_TRUE(ws2.is_valid());
+}
+
+TEST(WebSocketTest, ComplexPath) {
+  ws::WebSocketClient ws1("ws://localhost:8080/path/to/endpoint");
+  EXPECT_TRUE(ws1.is_valid());
+
+  ws::WebSocketClient ws2("ws://localhost:8080/");
+  EXPECT_TRUE(ws2.is_valid());
+}
+
+class WebSocketIntegrationTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    server_ = httplib::detail::make_unique<Server>();
+    setup_server();
+    start_server();
+  }
+
+  void TearDown() override {
+    server_->stop();
+    if (server_thread_.joinable()) { server_thread_.join(); }
+  }
+
+  void setup_server() {
+    server_->WebSocket("/ws-echo", [](const Request &, ws::WebSocket &ws) {
+      std::string msg;
+      ws::ReadResult ret;
+      while ((ret = ws.read(msg))) {
+        if (ret == ws::Binary) {
+          ws.send(msg.data(), msg.size());
+        } else {
+          ws.send(msg);
+        }
+      }
+    });
+
+    server_->WebSocket("/ws-echo-string",
+                       [](const Request &, ws::WebSocket &ws) {
+                         std::string msg;
+                         while (ws.read(msg)) {
+                           ws.send("echo: " + msg);
+                         }
+                       });
+
+    server_->WebSocket(
+        "/ws-request-info", [](const Request &req, ws::WebSocket &ws) {
+          // Echo back request metadata
+          ws.send("path:" + req.path);
+          ws.send("header:" + req.get_header_value("X-Test-Header"));
+          std::string msg;
+          while (ws.read(msg)) {}
+        });
+
+    server_->WebSocket("/ws-close", [](const Request &, ws::WebSocket &ws) {
+      std::string msg;
+      ws.read(msg); // wait for a message
+      ws.close();
+    });
+
+    server_->WebSocket("/ws-close-status",
+                       [](const Request &, ws::WebSocket &ws) {
+                         std::string msg;
+                         ws.read(msg); // wait for a message
+                         ws.close(ws::CloseStatus::GoingAway, "shutting down");
+                       });
+
+    server_->WebSocket(
+        "/ws-subprotocol",
+        [](const Request &, ws::WebSocket &ws) {
+          std::string msg;
+          while (ws.read(msg)) {
+            ws.send(msg);
+          }
+        },
+        [](const std::vector<std::string> &protocols) -> std::string {
+          for (const auto &p : protocols) {
+            if (p == "graphql-ws") { return p; }
+          }
+          return "";
+        });
+  }
+
+  void start_server() {
+    port_ = server_->bind_to_any_port(HOST);
+    server_thread_ = std::thread([this]() { server_->listen_after_bind(); });
+    server_->wait_until_ready();
+  }
+
+  std::unique_ptr<Server> server_;
+  std::thread server_thread_;
+  int port_ = 0;
+};
+
+TEST_F(WebSocketIntegrationTest, TextEcho) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+  ASSERT_TRUE(client.is_open());
+
+  ASSERT_TRUE(client.send("Hello WebSocket"));
+  std::string msg;
+  EXPECT_EQ(ws::Text, client.read(msg));
+  EXPECT_EQ("Hello WebSocket", msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, BinaryEcho) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  std::string binary_data = {'\x00', '\x01', '\x02', '\xFF', '\xFE'};
+  ASSERT_TRUE(client.send(binary_data.data(), binary_data.size()));
+
+  std::string msg;
+  EXPECT_EQ(ws::Binary, client.read(msg));
+  EXPECT_EQ(binary_data, msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, MultipleMessages) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  for (int i = 0; i < 10; i++) {
+    auto text = "message " + std::to_string(i);
+    ASSERT_TRUE(client.send(text));
+    std::string msg;
+    ASSERT_TRUE(client.read(msg));
+    EXPECT_EQ(text, msg);
+  }
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, CloseHandshake) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-close");
+  ASSERT_TRUE(client.connect());
+
+  // Send a message to trigger the server to close
+  ASSERT_TRUE(client.send("trigger close"));
+
+  // The server will close, so read should return false
+  std::string msg;
+  EXPECT_FALSE(client.read(msg));
+  EXPECT_FALSE(client.is_open());
+}
+
+TEST_F(WebSocketIntegrationTest, LargeMessage) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  // 128KB message
+  std::string large_data(128 * 1024, 'X');
+  ASSERT_TRUE(client.send(large_data));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ(large_data, msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, ConcurrentSend) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  const int num_threads = 4;
+  std::vector<std::thread> threads;
+  std::atomic<int> send_count{0};
+
+  for (int t = 0; t < num_threads; t++) {
+    threads.emplace_back([&client, &send_count, t]() {
+      for (int i = 0; i < 5; i++) {
+        auto text = "thread" + std::to_string(t) + "_msg" + std::to_string(i);
+        if (client.send(text)) { send_count++; }
+      }
+    });
+  }
+
+  for (auto &th : threads) {
+    th.join();
+  }
+
+  int received = 0;
+  std::string msg;
+  while (received < send_count.load()) {
+    if (!client.read(msg)) { break; }
+    received++;
+  }
+  EXPECT_EQ(send_count.load(), received);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, ReadString) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo-string");
+  ASSERT_TRUE(client.connect());
+
+  ASSERT_TRUE(client.send("hello"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("echo: hello", msg);
+
+  ASSERT_TRUE(client.send("world"));
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("echo: world", msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, RequestAccess) {
+  Headers headers = {{"X-Test-Header", "test-value"}};
+  ws::WebSocketClient client(
+      "ws://localhost:" + std::to_string(port_) + "/ws-request-info", headers);
+  ASSERT_TRUE(client.connect());
+
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("path:/ws-request-info", msg);
+
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("header:test-value", msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, ReadTimeout) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  client.set_read_timeout(1, 0); // 1 second
+  ASSERT_TRUE(client.connect());
+
+  // Don't send anything — server echo handler waits for a message,
+  // so read() should time out and return false.
+  std::string msg;
+  EXPECT_FALSE(client.read(msg));
+}
+
+TEST_F(WebSocketIntegrationTest, MaxPayloadExceeded) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  client.set_read_timeout(5, 0);
+  ASSERT_TRUE(client.connect());
+
+  // Send a message exceeding CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH (16MB).
+  // The server should reject it and close the connection.
+  std::string oversized(CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH + 1, 'A');
+  client.send(oversized);
+
+  // The server's read() should have failed due to payload limit,
+  // so our read() should return false (connection closed).
+  std::string msg;
+  EXPECT_FALSE(client.read(msg));
+}
+
+TEST_F(WebSocketIntegrationTest, MaxPayloadAtLimit) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  client.set_read_timeout(10, 0);
+  ASSERT_TRUE(client.connect());
+
+  // Send a message exactly at CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH (16MB).
+  // This should succeed.
+  std::string at_limit(CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH, 'B');
+  ASSERT_TRUE(client.send(at_limit));
+
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ(at_limit.size(), msg.size());
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, ConnectToInvalidPath) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/nonexistent");
+  EXPECT_FALSE(client.connect());
+  EXPECT_FALSE(client.is_open());
+}
+
+TEST_F(WebSocketIntegrationTest, EmptyMessage) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  ASSERT_TRUE(client.send(""));
+  std::string msg;
+  EXPECT_EQ(ws::Text, client.read(msg));
+  EXPECT_EQ("", msg);
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, Reconnect) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+
+  // First connection
+  ASSERT_TRUE(client.connect());
+  ASSERT_TRUE(client.send("first"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("first", msg);
+  client.close();
+  EXPECT_FALSE(client.is_open());
+
+  // Reconnect using the same client object
+  ASSERT_TRUE(client.connect());
+  ASSERT_TRUE(client.is_open());
+  ASSERT_TRUE(client.send("second"));
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("second", msg);
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, CloseWithStatus) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-close-status");
+  ASSERT_TRUE(client.connect());
+
+  // Trigger the server to close with GoingAway status
+  ASSERT_TRUE(client.send("trigger"));
+
+  // read() should return false after receiving the close frame
+  std::string msg;
+  EXPECT_FALSE(client.read(msg));
+  EXPECT_FALSE(client.is_open());
+}
+
+TEST_F(WebSocketIntegrationTest, ClientCloseWithStatus) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  ASSERT_TRUE(client.connect());
+
+  client.close(ws::CloseStatus::GoingAway, "client leaving");
+  EXPECT_FALSE(client.is_open());
+}
+
+TEST_F(WebSocketIntegrationTest, SubProtocolNegotiation) {
+  Headers headers = {{"Sec-WebSocket-Protocol", "mqtt, graphql-ws"}};
+  ws::WebSocketClient client(
+      "ws://localhost:" + std::to_string(port_) + "/ws-subprotocol", headers);
+  ASSERT_TRUE(client.connect());
+
+  // Server should have selected graphql-ws
+  EXPECT_EQ("graphql-ws", client.subprotocol());
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, SubProtocolNoMatch) {
+  Headers headers = {{"Sec-WebSocket-Protocol", "mqtt, wamp"}};
+  ws::WebSocketClient client(
+      "ws://localhost:" + std::to_string(port_) + "/ws-subprotocol", headers);
+  ASSERT_TRUE(client.connect());
+
+  // Server should not have selected any subprotocol
+  EXPECT_TRUE(client.subprotocol().empty());
+
+  client.close();
+}
+
+TEST_F(WebSocketIntegrationTest, SubProtocolNotRequested) {
+  // Connect without requesting any subprotocol
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) +
+                             "/ws-subprotocol");
+  ASSERT_TRUE(client.connect());
+
+  EXPECT_TRUE(client.subprotocol().empty());
+
+  client.close();
+}
+
+TEST(WebSocketPreRoutingTest, RejectWithoutAuth) {
+  Server svr;
+
+  svr.set_pre_routing_handler([](const Request &req, Response &res) {
+    if (!req.has_header("Authorization")) {
+      res.status = StatusCode::Unauthorized_401;
+      res.set_content("Unauthorized", "text/plain");
+      return Server::HandlerResponse::Handled;
+    }
+    return Server::HandlerResponse::Unhandled;
+  });
+
+  svr.WebSocket("/ws", [](const Request &, ws::WebSocket &ws) {
+    std::string msg;
+    while (ws.read(msg)) {
+      ws.send(msg);
+    }
+  });
+
+  auto port = svr.bind_to_any_port("localhost");
+  std::thread t([&]() { svr.listen_after_bind(); });
+  svr.wait_until_ready();
+
+  // Without Authorization header - should be rejected before upgrade
+  ws::WebSocketClient client1("ws://localhost:" + std::to_string(port) + "/ws");
+  EXPECT_FALSE(client1.connect());
+
+  // With Authorization header - should succeed
+  Headers headers = {{"Authorization", "Bearer token123"}};
+  ws::WebSocketClient client2("ws://localhost:" + std::to_string(port) + "/ws",
+                              headers);
+  ASSERT_TRUE(client2.connect());
+  ASSERT_TRUE(client2.send("hello"));
+  std::string msg;
+  ASSERT_TRUE(client2.read(msg));
+  EXPECT_EQ("hello", msg);
+  client2.close();
+
+  svr.stop();
+  t.join();
+}
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+class WebSocketSSLIntegrationTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    server_ = httplib::detail::make_unique<SSLServer>(SERVER_CERT_FILE,
+                                                      SERVER_PRIVATE_KEY_FILE);
+    server_->WebSocket("/ws-echo", [](const Request &, ws::WebSocket &ws) {
+      std::string msg;
+      ws::ReadResult ret;
+      while ((ret = ws.read(msg))) {
+        if (ret == ws::Binary) {
+          ws.send(msg.data(), msg.size());
+        } else {
+          ws.send(msg);
+        }
+      }
+    });
+    port_ = server_->bind_to_any_port(HOST);
+    server_thread_ = std::thread([this]() { server_->listen_after_bind(); });
+    server_->wait_until_ready();
+  }
+
+  void TearDown() override {
+    server_->stop();
+    if (server_thread_.joinable()) { server_thread_.join(); }
+  }
+
+  std::unique_ptr<SSLServer> server_;
+  std::thread server_thread_;
+  int port_ = 0;
+};
+
+TEST_F(WebSocketSSLIntegrationTest, TextEcho) {
+  ws::WebSocketClient client("wss://localhost:" + std::to_string(port_) +
+                             "/ws-echo");
+  client.enable_server_certificate_verification(false);
+  ASSERT_TRUE(client.connect());
+  ASSERT_TRUE(client.is_open());
+
+  ASSERT_TRUE(client.send("Hello WSS"));
+  std::string msg;
+  EXPECT_EQ(ws::Text, client.read(msg));
+  EXPECT_EQ("Hello WSS", msg);
+
+  client.close();
+}
+#endif

+ 228 - 0
test/test_thread_pool.cc

@@ -0,0 +1,228 @@
+// ThreadPool unit tests
+// Set a short idle timeout for faster shrink tests
+#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 1
+
+#include <httplib.h>
+
+#include <gtest/gtest.h>
+
+#include <atomic>
+#include <chrono>
+#include <thread>
+#include <vector>
+
+using namespace httplib;
+
+TEST(ThreadPoolTest, BasicTaskExecution) {
+  ThreadPool pool(4);
+  std::atomic<int> count(0);
+
+  for (int i = 0; i < 10; i++) {
+    pool.enqueue([&count]() { count++; });
+  }
+
+  pool.shutdown();
+  EXPECT_EQ(10, count.load());
+}
+
+TEST(ThreadPoolTest, FixedPoolWhenMaxEqualsBase) {
+  // max_n == 0 means max = base (fixed pool behavior)
+  ThreadPool pool(4);
+  std::atomic<int> count(0);
+
+  for (int i = 0; i < 100; i++) {
+    pool.enqueue([&count]() { count++; });
+  }
+
+  pool.shutdown();
+  EXPECT_EQ(100, count.load());
+}
+
+TEST(ThreadPoolTest, DynamicScaleUp) {
+  // base=2, max=8: block 2 base threads, then enqueue more tasks
+  ThreadPool pool(2, 8);
+
+  std::atomic<int> active(0);
+  std::atomic<int> max_active(0);
+  std::atomic<int> completed(0);
+  std::mutex barrier_mutex;
+  std::condition_variable barrier_cv;
+  bool release = false;
+
+  // Occupy all base threads with blocking tasks
+  for (int i = 0; i < 2; i++) {
+    pool.enqueue([&]() {
+      active++;
+      {
+        std::unique_lock<std::mutex> lock(barrier_mutex);
+        barrier_cv.wait(lock, [&] { return release; });
+      }
+      active--;
+      completed++;
+    });
+  }
+
+  // Wait for base threads to be occupied
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  // These should trigger dynamic thread creation
+  for (int i = 0; i < 4; i++) {
+    pool.enqueue([&]() {
+      int cur = ++active;
+      // Track peak active count
+      int prev = max_active.load();
+      while (cur > prev && !max_active.compare_exchange_weak(prev, cur)) {}
+      std::this_thread::sleep_for(std::chrono::milliseconds(50));
+      active--;
+      completed++;
+    });
+  }
+
+  // Wait for dynamic tasks to complete
+  std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+  // Release the blocking tasks
+  {
+    std::unique_lock<std::mutex> lock(barrier_mutex);
+    release = true;
+  }
+  barrier_cv.notify_all();
+
+  pool.shutdown();
+  EXPECT_EQ(6, completed.load());
+  // More than 2 threads were active simultaneously
+  EXPECT_GT(max_active.load(), 2);
+}
+
+TEST(ThreadPoolTest, DynamicShrinkAfterIdle) {
+  // CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT is set to 1 second
+  ThreadPool pool(2, 8);
+
+  std::atomic<int> completed(0);
+
+  // Enqueue tasks that require dynamic threads
+  for (int i = 0; i < 8; i++) {
+    pool.enqueue([&]() {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      completed++;
+    });
+  }
+
+  // Wait for all tasks to complete + idle timeout + margin
+  std::this_thread::sleep_for(std::chrono::milliseconds(2500));
+
+  // Now enqueue a simple task to verify the pool still works
+  // (base threads are still alive)
+  std::atomic<bool> final_task_done(false);
+  pool.enqueue([&]() { final_task_done = true; });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  pool.shutdown();
+  EXPECT_EQ(8, completed.load());
+  EXPECT_TRUE(final_task_done.load());
+}
+
+TEST(ThreadPoolTest, ShutdownWithActiveDynamicThreads) {
+  ThreadPool pool(2, 8);
+
+  std::atomic<int> started(0);
+
+  std::mutex block_mutex;
+  std::condition_variable block_cv;
+  bool release = false;
+
+  // Start tasks on dynamic threads that block until released
+  for (int i = 0; i < 6; i++) {
+    pool.enqueue([&]() {
+      started++;
+      std::unique_lock<std::mutex> lock(block_mutex);
+      block_cv.wait(lock, [&] { return release; });
+    });
+  }
+
+  // Wait for tasks to start
+  std::this_thread::sleep_for(std::chrono::milliseconds(200));
+  EXPECT_GE(started.load(), 2);
+
+  // Release all blocked threads, then shutdown
+  {
+    std::unique_lock<std::mutex> lock(block_mutex);
+    release = true;
+  }
+  block_cv.notify_all();
+
+  pool.shutdown();
+}
+
+TEST(ThreadPoolTest, MaxQueuedRequests) {
+  // base=2, max=2 (fixed), mqr=3
+  ThreadPool pool(2, 2, 3);
+
+  std::mutex block_mutex;
+  std::condition_variable block_cv;
+  bool release = false;
+
+  // Block both threads
+  for (int i = 0; i < 2; i++) {
+    EXPECT_TRUE(pool.enqueue([&]() {
+      std::unique_lock<std::mutex> lock(block_mutex);
+      block_cv.wait(lock, [&] { return release; });
+    }));
+  }
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+  // Fill the queue up to max_queued_requests
+  EXPECT_TRUE(pool.enqueue([]() {}));
+  EXPECT_TRUE(pool.enqueue([]() {}));
+  EXPECT_TRUE(pool.enqueue([]() {}));
+
+  // This should fail - queue is full
+  EXPECT_FALSE(pool.enqueue([]() {}));
+
+  // Release blocked threads
+  {
+    std::unique_lock<std::mutex> lock(block_mutex);
+    release = true;
+  }
+  block_cv.notify_all();
+
+  pool.shutdown();
+}
+
+#ifndef CPPHTTPLIB_NO_EXCEPTIONS
+TEST(ThreadPoolTest, InvalidMaxThreadsThrows) {
+  // max_n < n should throw
+  EXPECT_THROW(ThreadPool(8, 4), std::invalid_argument);
+}
+#endif
+
+TEST(ThreadPoolTest, EnqueueAfterShutdownReturnsFalse) {
+  ThreadPool pool(2);
+  pool.shutdown();
+  EXPECT_FALSE(pool.enqueue([]() {}));
+}
+
+TEST(ThreadPoolTest, ConcurrentEnqueue) {
+  ThreadPool pool(4, 16);
+  std::atomic<int> count(0);
+  const int num_producers = 4;
+  const int tasks_per_producer = 100;
+
+  std::vector<std::thread> producers;
+  for (int p = 0; p < num_producers; p++) {
+    producers.emplace_back([&]() {
+      for (int i = 0; i < tasks_per_producer; i++) {
+        pool.enqueue([&count]() { count++; });
+      }
+    });
+  }
+
+  for (auto &t : producers) {
+    t.join();
+  }
+
+  pool.shutdown();
+  EXPECT_EQ(num_producers * tasks_per_producer, count.load());
+}

+ 77 - 0
test/test_websocket_heartbeat.cc

@@ -0,0 +1,77 @@
+// Standalone test for WebSocket automatic heartbeat.
+// Compiled with a 1-second ping interval so we can verify heartbeat behavior
+// without waiting 30 seconds.
+
+#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 1
+#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 3
+#include <httplib.h>
+
+#include "gtest/gtest.h"
+
+using namespace httplib;
+
+class WebSocketHeartbeatTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    svr_.WebSocket("/ws", [](const Request &, ws::WebSocket &ws) {
+      std::string msg;
+      while (ws.read(msg)) {
+        ws.send(msg);
+      }
+    });
+
+    port_ = svr_.bind_to_any_port("localhost");
+    thread_ = std::thread([this]() { svr_.listen_after_bind(); });
+    svr_.wait_until_ready();
+  }
+
+  void TearDown() override {
+    svr_.stop();
+    thread_.join();
+  }
+
+  Server svr_;
+  int port_;
+  std::thread thread_;
+};
+
+// Verify that an idle connection stays alive beyond the read timeout
+// thanks to automatic heartbeat pings.
+TEST_F(WebSocketHeartbeatTest, IdleConnectionStaysAlive) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");
+  ASSERT_TRUE(client.connect());
+
+  // Sleep longer than read timeout (3s). Without heartbeat, the connection
+  // would time out. With heartbeat pings every 1s, it stays alive.
+  std::this_thread::sleep_for(std::chrono::seconds(5));
+
+  // Connection should still be open
+  ASSERT_TRUE(client.is_open());
+
+  // Verify we can still exchange messages
+  ASSERT_TRUE(client.send("hello after idle"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("hello after idle", msg);
+
+  client.close();
+}
+
+// Verify that multiple heartbeat cycles work
+TEST_F(WebSocketHeartbeatTest, MultipleHeartbeatCycles) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");
+  ASSERT_TRUE(client.connect());
+
+  // Wait through several heartbeat cycles
+  for (int i = 0; i < 3; i++) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1500));
+    ASSERT_TRUE(client.is_open());
+    std::string text = "msg" + std::to_string(i);
+    ASSERT_TRUE(client.send(text));
+    std::string msg;
+    ASSERT_TRUE(client.read(msg));
+    EXPECT_EQ(text, msg);
+  }
+
+  client.close();
+}

Some files were not shown because too many files changed in this diff