소스 검색

SSE Client: Update Authorization Header
Fixes #2402

yhirose 1 주 전
부모
커밋
c2bdb1c5c1
3개의 변경된 파일138개의 추가작업 그리고 10개의 파일을 삭제
  1. 22 0
      README-sse.md
  2. 24 10
      httplib.h
  3. 92 0
      test/test.cc

+ 22 - 0
README-sse.md

@@ -74,6 +74,9 @@ sse.set_reconnect_interval(5000);
 
 // Set max reconnect attempts (default: 0 = unlimited)
 sse.set_max_reconnect_attempts(10);
+
+// Update headers at any time (thread-safe)
+sse.set_headers({{"Authorization", "Bearer new_token"}});
 ```
 
 #### Control
@@ -154,6 +157,25 @@ httplib::sse::SSEClient sse(cli, "/events", headers);
 sse.start();
 ```
 
+### Refreshing Auth Token on Reconnect
+
+```cpp
+httplib::sse::SSEClient sse(cli, "/events",
+    {{"Authorization", "Bearer " + get_token()}});
+
+// Preemptively refresh token on each successful connection
+sse.on_open([&sse]() {
+    sse.set_headers({{"Authorization", "Bearer " + get_token()}});
+});
+
+// Or reactively refresh on auth failure (401 triggers reconnect)
+sse.on_error([&sse](httplib::Error) {
+    sse.set_headers({{"Authorization", "Bearer " + refresh_token()}});
+});
+
+sse.start();
+```
+
 ### Error Handling
 
 ```cpp

+ 24 - 10
httplib.h

@@ -3660,6 +3660,9 @@ public:
   SSEClient &set_reconnect_interval(int ms);
   SSEClient &set_max_reconnect_attempts(int n);
 
+  // Update headers (thread-safe)
+  SSEClient &set_headers(const Headers &headers);
+
   // State accessors
   bool is_connected() const;
   const std::string &last_event_id() const;
@@ -3684,6 +3687,7 @@ private:
   Client &client_;
   std::string path_;
   Headers headers_;
+  mutable std::mutex headers_mutex_;
 
   // Callbacks
   MessageHandler on_message_;
@@ -3992,6 +3996,12 @@ inline SSEClient &SSEClient::set_max_reconnect_attempts(int n) {
   return *this;
 }
 
+inline SSEClient &SSEClient::set_headers(const Headers &headers) {
+  std::lock_guard<std::mutex> lock(headers_mutex_);
+  headers_ = headers;
+  return *this;
+}
+
 inline bool SSEClient::is_connected() const { return connected_.load(); }
 
 inline const std::string &SSEClient::last_event_id() const {
@@ -4070,7 +4080,11 @@ inline void SSEClient::run_event_loop() {
 
   while (running_.load()) {
     // Build headers, including Last-Event-ID if we have one
-    auto request_headers = headers_;
+    Headers request_headers;
+    {
+      std::lock_guard<std::mutex> lock(headers_mutex_);
+      request_headers = headers_;
+    }
     if (!last_event_id_.empty()) {
       request_headers.emplace("Last-Event-ID", last_event_id_);
     }
@@ -4089,19 +4103,19 @@ inline void SSEClient::run_event_loop() {
       continue;
     }
 
-    if (result.status() != 200) {
+    if (result.status() != StatusCode::OK_200) {
       connected_.store(false);
-      // For certain errors, don't reconnect
-      if (result.status() == 204 || // No Content - server wants us to stop
-          result.status() == 404 || // Not Found
-          result.status() == 401 || // Unauthorized
-          result.status() == 403) { // Forbidden
-        if (on_error_) { on_error_(Error::Connection); }
+      if (on_error_) { on_error_(Error::Connection); }
+
+      // For certain errors, don't reconnect.
+      // Note: 401 is intentionally absent so that handlers can refresh
+      // credentials via set_headers() and let the client reconnect.
+      if (result.status() == StatusCode::NoContent_204 ||
+          result.status() == StatusCode::NotFound_404 ||
+          result.status() == StatusCode::Forbidden_403) {
         break;
       }
 
-      if (on_error_) { on_error_(Error::Connection); }
-
       if (!should_reconnect(reconnect_count)) { break; }
       wait_for_reconnect();
       reconnect_count++;

+ 92 - 0
test/test.cc

@@ -16007,6 +16007,98 @@ TEST_F(SSEIntegrationTest, LastEventIdSentOnReconnect) {
   }
 }
 
+// Test: set_headers updates headers used on reconnect
+TEST_F(SSEIntegrationTest, SetHeadersUpdatesOnReconnect) {
+  std::vector<std::string> received_tokens;
+  std::mutex token_mutex;
+
+  // Endpoint that captures Authorization header
+  server_->Get("/auth-check", [&](const Request &req, Response &res) {
+    {
+      std::lock_guard<std::mutex> lock(token_mutex);
+      received_tokens.push_back(req.get_header_value("Authorization"));
+    }
+    res.set_chunked_content_provider(
+        "text/event-stream", [](size_t offset, DataSink &sink) {
+          if (offset == 0) {
+            std::string event = "data: hello\n\n";
+            sink.write(event.data(), event.size());
+          }
+          return false; // Close connection to trigger reconnect
+        });
+  });
+
+  Client client("localhost", get_port());
+  Headers headers = {{"Authorization", "Bearer old-token"}};
+  sse::SSEClient sse(client, "/auth-check", headers);
+
+  // Update headers on each successful connection
+  sse.on_open(
+      [&sse]() { sse.set_headers({{"Authorization", "Bearer new-token"}}); });
+
+  sse.set_reconnect_interval(100);
+  sse.set_max_reconnect_attempts(3);
+  sse.start_async();
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(800));
+  sse.stop();
+
+  std::lock_guard<std::mutex> lock(token_mutex);
+  ASSERT_GE(received_tokens.size(), 2u);
+  // First connection uses original header
+  EXPECT_EQ(received_tokens[0], "Bearer old-token");
+  // Second connection uses updated header from set_headers
+  EXPECT_EQ(received_tokens[1], "Bearer new-token");
+}
+
+// Test: 401 allows reconnection (so on_error can refresh headers)
+TEST_F(SSEIntegrationTest, ReconnectOn401WithHeaderRefresh) {
+  std::atomic<int> connection_count{0};
+
+  // Endpoint: returns 401 on first attempt, 200 on second
+  server_->Get("/auth-retry", [&](const Request &req, Response &res) {
+    int conn = connection_count.fetch_add(1);
+    if (conn == 0 || req.get_header_value("Authorization") != "Bearer valid") {
+      res.status = StatusCode::Unauthorized_401;
+      res.set_content("Unauthorized", "text/plain");
+      return;
+    }
+    res.set_chunked_content_provider(
+        "text/event-stream", [](size_t offset, DataSink &sink) {
+          if (offset == 0) {
+            std::string event = "data: authenticated\n\n";
+            sink.write(event.data(), event.size());
+          }
+          return false;
+        });
+  });
+
+  Client client("localhost", get_port());
+  Headers headers = {{"Authorization", "Bearer expired"}};
+  sse::SSEClient sse(client, "/auth-retry", headers);
+
+  std::atomic<bool> message_received{false};
+
+  // Refresh token on error
+  sse.on_error(
+      [&sse](Error) { sse.set_headers({{"Authorization", "Bearer valid"}}); });
+
+  sse.on_message([&](const sse::SSEMessage &msg) {
+    if (msg.data == "authenticated") { message_received.store(true); }
+  });
+
+  sse.set_reconnect_interval(100);
+  sse.set_max_reconnect_attempts(3);
+  sse.start_async();
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(800));
+  sse.stop();
+
+  // Should have reconnected after 401 and succeeded with new token
+  EXPECT_GE(connection_count.load(), 2);
+  EXPECT_TRUE(message_received.load());
+}
+
 TEST(Issue2318Test, EmptyHostString) {
   {
     httplib::Client cli_empty("", PORT);