1
0
Эх сурвалжийг харах

Add runtime configuration for WebSocket ping interval and related tests

yhirose 2 долоо хоног өмнө
parent
commit
257b266190

+ 20 - 0
README-websocket.md

@@ -353,6 +353,26 @@ if (ws.connect()) {
 | `CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND` | `5`               | Timeout for waiting peer's Close response (seconds)      |
 | `CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND` | `30`              | Automatic Ping interval for heartbeat (seconds)          |
 
+### Runtime Ping Interval
+
+You can override the ping interval at runtime instead of changing the compile-time macro. Set it to `0` to disable automatic pings entirely.
+
+```cpp
+// Server side
+httplib::Server svr;
+svr.set_websocket_ping_interval(10);  // 10 seconds
+
+// Or using std::chrono
+svr.set_websocket_ping_interval(std::chrono::seconds(10));
+
+// Client side
+httplib::ws::WebSocketClient ws("ws://localhost:8080/ws");
+ws.set_websocket_ping_interval(10);  // 10 seconds
+
+// Disable automatic pings
+ws.set_websocket_ping_interval(0);
+```
+
 ## Threading Model
 
 WebSocket connections share the same thread pool as HTTP requests. Each WebSocket connection occupies one thread for its entire lifetime.

+ 43 - 9
httplib.h

@@ -1670,6 +1670,11 @@ public:
 
   Server &set_payload_max_length(size_t length);
 
+  Server &set_websocket_ping_interval(time_t sec);
+  template <class Rep, class Period>
+  Server &set_websocket_ping_interval(
+      const std::chrono::duration<Rep, Period> &duration);
+
   bool bind_to_port(const std::string &host, int port, int socket_flags = 0);
   int bind_to_any_port(const std::string &host, int socket_flags = 0);
   bool listen_after_bind();
@@ -1704,6 +1709,8 @@ protected:
   time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND;
   time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND;
   size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH;
+  time_t websocket_ping_interval_sec_ =
+      CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
 
 private:
   using Handlers =
@@ -3729,15 +3736,19 @@ private:
   friend class httplib::Server;
   friend class WebSocketClient;
 
-  WebSocket(Stream &strm, const Request &req, bool is_server)
-      : strm_(strm), req_(req), is_server_(is_server) {
+  WebSocket(
+      Stream &strm, const Request &req, bool is_server,
+      time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
+      : strm_(strm), req_(req), is_server_(is_server),
+        ping_interval_sec_(ping_interval_sec) {
     start_heartbeat();
   }
 
-  WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
-            bool is_server)
+  WebSocket(
+      std::unique_ptr<Stream> &&owned_strm, const Request &req, bool is_server,
+      time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)
       : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
-        is_server_(is_server) {
+        is_server_(is_server), ping_interval_sec_(ping_interval_sec) {
     start_heartbeat();
   }
 
@@ -3748,6 +3759,7 @@ private:
   std::unique_ptr<Stream> owned_strm_;
   Request req_;
   bool is_server_;
+  time_t ping_interval_sec_;
   std::atomic<bool> closed_{false};
   std::mutex write_mutex_;
   std::thread ping_thread_;
@@ -3776,6 +3788,7 @@ public:
   const std::string &subprotocol() const;
   void set_read_timeout(time_t sec, time_t usec = 0);
   void set_write_timeout(time_t sec, time_t usec = 0);
+  void set_websocket_ping_interval(time_t sec);
 
 #ifdef CPPHTTPLIB_SSL_ENABLED
   void set_ca_cert_path(const std::string &path);
@@ -3799,6 +3812,8 @@ private:
   time_t read_timeout_usec_ = 0;
   time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
   time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
+  time_t websocket_ping_interval_sec_ =
+      CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND;
 
 #ifdef CPPHTTPLIB_SSL_ENABLED
   bool is_ssl_ = false;
@@ -10814,6 +10829,20 @@ inline Server &Server::set_payload_max_length(size_t length) {
   return *this;
 }
 
+inline Server &Server::set_websocket_ping_interval(time_t sec) {
+  websocket_ping_interval_sec_ = sec;
+  return *this;
+}
+
+template <class Rep, class Period>
+inline Server &Server::set_websocket_ping_interval(
+    const std::chrono::duration<Rep, Period> &duration) {
+  detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) {
+    set_websocket_ping_interval(sec);
+  });
+  return *this;
+}
+
 inline bool Server::bind_to_port(const std::string &host, int port,
                                  int socket_flags) {
   auto ret = bind_internal(host, port, socket_flags);
@@ -11964,7 +11993,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
         {
           // Use WebSocket-specific read timeout instead of HTTP timeout
           strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0);
-          ws::WebSocket ws(strm, req, true);
+          ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_);
           entry.handler(req, ws);
         }
         return true;
@@ -20017,11 +20046,11 @@ inline WebSocket::~WebSocket() {
 }
 
 inline void WebSocket::start_heartbeat() {
+  if (ping_interval_sec_ == 0) { return; }
   ping_thread_ = std::thread([this]() {
     std::unique_lock<std::mutex> lock(ping_mutex_);
     while (!closed_) {
-      ping_cv_.wait_for(lock, std::chrono::seconds(
-                                  CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND));
+      ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_));
       if (closed_) { break; }
       lock.unlock();
       if (!send_frame(Opcode::Ping, nullptr, 0)) {
@@ -20159,7 +20188,8 @@ inline bool WebSocketClient::connect() {
   Request req;
   req.method = "GET";
   req.path = path_;
-  ws_ = std::unique_ptr<WebSocket>(new WebSocket(std::move(strm), req, false));
+  ws_ = std::unique_ptr<WebSocket>(
+      new WebSocket(std::move(strm), req, false, websocket_ping_interval_sec_));
   return true;
 }
 
@@ -20199,6 +20229,10 @@ inline void WebSocketClient::set_write_timeout(time_t sec, time_t usec) {
   write_timeout_usec_ = usec;
 }
 
+inline void WebSocketClient::set_websocket_ping_interval(time_t sec) {
+  websocket_ping_interval_sec_ = sec;
+}
+
 #ifdef CPPHTTPLIB_SSL_ENABLED
 
 inline void WebSocketClient::set_ca_cert_path(const std::string &path) {

+ 79 - 0
test/test_websocket_heartbeat.cc

@@ -57,6 +57,85 @@ TEST_F(WebSocketHeartbeatTest, IdleConnectionStaysAlive) {
   client.close();
 }
 
+// Verify that set_websocket_ping_interval overrides the compile-time default
+TEST_F(WebSocketHeartbeatTest, RuntimePingIntervalOverride) {
+  // The server is already using the compile-time default (1s).
+  // Create a client with a custom runtime interval.
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");
+  client.set_websocket_ping_interval(2);
+  ASSERT_TRUE(client.connect());
+
+  // Sleep longer than read timeout (3s). Client heartbeat at 2s keeps alive.
+  std::this_thread::sleep_for(std::chrono::seconds(5));
+
+  ASSERT_TRUE(client.is_open());
+  ASSERT_TRUE(client.send("runtime interval"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("runtime interval", msg);
+
+  client.close();
+}
+
+// Verify that ping_interval=0 disables heartbeat without breaking basic I/O.
+TEST_F(WebSocketHeartbeatTest, ZeroDisablesHeartbeat) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");
+  client.set_websocket_ping_interval(0);
+  ASSERT_TRUE(client.connect());
+
+  // Basic send/receive still works with heartbeat disabled
+  ASSERT_TRUE(client.send("no client ping"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("no client ping", msg);
+
+  client.close();
+}
+
+// Verify that Server::set_websocket_ping_interval works at runtime
+class WebSocketServerPingIntervalTest : public ::testing::Test {
+protected:
+  void SetUp() override {
+    svr_.set_websocket_ping_interval(2);
+    svr_.WebSocket("/ws", [](const Request &, ws::WebSocket &ws) {
+      std::string msg;
+      while (ws.read(msg)) {
+        ws.send(msg);
+      }
+    });
+
+    port_ = svr_.bind_to_any_port("localhost");
+    thread_ = std::thread([this]() { svr_.listen_after_bind(); });
+    svr_.wait_until_ready();
+  }
+
+  void TearDown() override {
+    svr_.stop();
+    thread_.join();
+  }
+
+  Server svr_;
+  int port_;
+  std::thread thread_;
+};
+
+TEST_F(WebSocketServerPingIntervalTest, ServerRuntimeInterval) {
+  ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");
+  ASSERT_TRUE(client.connect());
+
+  // Server ping interval is 2s; client uses compile-time default (1s).
+  // Both keep the connection alive.
+  std::this_thread::sleep_for(std::chrono::seconds(5));
+
+  ASSERT_TRUE(client.is_open());
+  ASSERT_TRUE(client.send("server interval"));
+  std::string msg;
+  ASSERT_TRUE(client.read(msg));
+  EXPECT_EQ("server interval", msg);
+
+  client.close();
+}
+
 // Verify that multiple heartbeat cycles work
 TEST_F(WebSocketHeartbeatTest, MultipleHeartbeatCycles) {
   ws::WebSocketClient client("ws://localhost:" + std::to_string(port_) + "/ws");