yhirose hai 8 meses
pai
achega
365cbe37fa
Modificáronse 3 ficheiros con 104 adicións e 5 borrados
  1. 16 0
      README.md
  2. 30 4
      httplib.h
  3. 58 1
      test/test.cc

+ 16 - 0
README.md

@@ -285,6 +285,22 @@ svr.set_post_routing_handler([](const auto& req, auto& res) {
 });
 });
 ```
 ```
 
 
+### Pre request handler
+
+```cpp
+svr.set_pre_request_handler([](const auto& req, auto& res) {
+  if (req.matched_route == "/user/:user") {
+    auto user = req.path_params.at("user");
+    if (user != "john") {
+      res.status = StatusCode::Forbidden_403;
+      res.set_content("error", "text/html");
+      return Server::HandlerResponse::Handled;
+    }
+  }
+  return Server::HandlerResponse::Unhandled;
+});
+```
+
 ### 'multipart/form-data' POST data
 ### 'multipart/form-data' POST data
 
 
 ```cpp
 ```cpp

+ 30 - 4
httplib.h

@@ -636,6 +636,7 @@ using Ranges = std::vector<Range>;
 struct Request {
 struct Request {
   std::string method;
   std::string method;
   std::string path;
   std::string path;
+  std::string matched_route;
   Params params;
   Params params;
   Headers headers;
   Headers headers;
   std::string body;
   std::string body;
@@ -887,10 +888,16 @@ namespace detail {
 
 
 class MatcherBase {
 class MatcherBase {
 public:
 public:
+  MatcherBase(std::string pattern) : pattern_(pattern) {}
   virtual ~MatcherBase() = default;
   virtual ~MatcherBase() = default;
 
 
+  const std::string &pattern() const { return pattern_; }
+
   // Match request path and populate its matches and
   // Match request path and populate its matches and
   virtual bool match(Request &request) const = 0;
   virtual bool match(Request &request) const = 0;
+
+private:
+  std::string pattern_;
 };
 };
 
 
 /**
 /**
@@ -942,7 +949,8 @@ private:
  */
  */
 class RegexMatcher final : public MatcherBase {
 class RegexMatcher final : public MatcherBase {
 public:
 public:
-  RegexMatcher(const std::string &pattern) : regex_(pattern) {}
+  RegexMatcher(const std::string &pattern)
+      : MatcherBase(pattern), regex_(pattern) {}
 
 
   bool match(Request &request) const override;
   bool match(Request &request) const override;
 
 
@@ -1009,9 +1017,12 @@ public:
   }
   }
 
 
   Server &set_exception_handler(ExceptionHandler handler);
   Server &set_exception_handler(ExceptionHandler handler);
+
   Server &set_pre_routing_handler(HandlerWithResponse handler);
   Server &set_pre_routing_handler(HandlerWithResponse handler);
   Server &set_post_routing_handler(Handler handler);
   Server &set_post_routing_handler(Handler handler);
 
 
+  Server &set_pre_request_handler(HandlerWithResponse handler);
+
   Server &set_expect_100_continue_handler(Expect100ContinueHandler handler);
   Server &set_expect_100_continue_handler(Expect100ContinueHandler handler);
   Server &set_logger(Logger logger);
   Server &set_logger(Logger logger);
 
 
@@ -1153,6 +1164,7 @@ private:
   ExceptionHandler exception_handler_;
   ExceptionHandler exception_handler_;
   HandlerWithResponse pre_routing_handler_;
   HandlerWithResponse pre_routing_handler_;
   Handler post_routing_handler_;
   Handler post_routing_handler_;
+  HandlerWithResponse pre_request_handler_;
   Expect100ContinueHandler expect_100_continue_handler_;
   Expect100ContinueHandler expect_100_continue_handler_;
 
 
   Logger logger_;
   Logger logger_;
@@ -6224,7 +6236,8 @@ inline time_t BufferStream::duration() const { return 0; }
 
 
 inline const std::string &BufferStream::get_buffer() const { return buffer; }
 inline const std::string &BufferStream::get_buffer() const { return buffer; }
 
 
-inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) {
+inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern)
+    : MatcherBase(pattern) {
   constexpr const char marker[] = "/:";
   constexpr const char marker[] = "/:";
 
 
   // One past the last ending position of a path param substring
   // One past the last ending position of a path param substring
@@ -6475,6 +6488,11 @@ inline Server &Server::set_post_routing_handler(Handler handler) {
   return *this;
   return *this;
 }
 }
 
 
+inline Server &Server::set_pre_request_handler(HandlerWithResponse handler) {
+  pre_request_handler_ = std::move(handler);
+  return *this;
+}
+
 inline Server &Server::set_logger(Logger logger) {
 inline Server &Server::set_logger(Logger logger) {
   logger_ = std::move(logger);
   logger_ = std::move(logger);
   return *this;
   return *this;
@@ -7129,7 +7147,11 @@ inline bool Server::dispatch_request(Request &req, Response &res,
     const auto &handler = x.second;
     const auto &handler = x.second;
 
 
     if (matcher->match(req)) {
     if (matcher->match(req)) {
-      handler(req, res);
+      req.matched_route = matcher->pattern();
+      if (!pre_request_handler_ ||
+          pre_request_handler_(req, res) != HandlerResponse::Handled) {
+        handler(req, res);
+      }
       return true;
       return true;
     }
     }
   }
   }
@@ -7256,7 +7278,11 @@ inline bool Server::dispatch_request_for_content_reader(
     const auto &handler = x.second;
     const auto &handler = x.second;
 
 
     if (matcher->match(req)) {
     if (matcher->match(req)) {
-      handler(req, res, content_reader);
+      req.matched_route = matcher->pattern();
+      if (!pre_request_handler_ ||
+          pre_request_handler_(req, res) != HandlerResponse::Handled) {
+        handler(req, res, content_reader);
+      }
       return true;
       return true;
     }
     }
   }
   }

+ 58 - 1
test/test.cc

@@ -2263,7 +2263,7 @@ TEST(NoContentTest, ContentLength) {
   }
   }
 }
 }
 
 
-TEST(RoutingHandlerTest, PreRoutingHandler) {
+TEST(RoutingHandlerTest, PreAndPostRoutingHandlers) {
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
   SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
   ASSERT_TRUE(svr.is_valid());
   ASSERT_TRUE(svr.is_valid());
@@ -2354,6 +2354,63 @@ TEST(RoutingHandlerTest, PreRoutingHandler) {
   }
   }
 }
 }
 
 
+TEST(RequestHandlerTest, PreRequestHandler) {
+  auto route_path = "/user/:user";
+
+  Server svr;
+
+  svr.Get("/hi", [](const Request &, Response &res) {
+    res.set_content("hi", "text/plain");
+  });
+
+  svr.Get(route_path, [](const Request &req, Response &res) {
+    res.set_content(req.path_params.at("user"), "text/plain");
+  });
+
+  svr.set_pre_request_handler([&](const Request &req, Response &res) {
+    if (req.matched_route == route_path) {
+      auto user = req.path_params.at("user");
+      if (user != "john") {
+        res.status = StatusCode::Forbidden_403;
+        res.set_content("error", "text/html");
+        return Server::HandlerResponse::Handled;
+      }
+    }
+    return Server::HandlerResponse::Unhandled;
+  });
+
+  auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
+  auto se = detail::scope_exit([&] {
+    svr.stop();
+    thread.join();
+    ASSERT_FALSE(svr.is_running());
+  });
+
+  svr.wait_until_ready();
+
+  Client cli(HOST, PORT);
+  {
+    auto res = cli.Get("/hi");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(StatusCode::OK_200, res->status);
+    EXPECT_EQ("hi", res->body);
+  }
+
+  {
+    auto res = cli.Get("/user/john");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(StatusCode::OK_200, res->status);
+    EXPECT_EQ("john", res->body);
+  }
+
+  {
+    auto res = cli.Get("/user/invalid-user");
+    ASSERT_TRUE(res);
+    EXPECT_EQ(StatusCode::Forbidden_403, res->status);
+    EXPECT_EQ("error", res->body);
+  }
+}
+
 TEST(InvalidFormatTest, StatusCode) {
 TEST(InvalidFormatTest, StatusCode) {
   Server svr;
   Server svr;