Explorar o código

Fix problem caused by the recent performance improvement

yhirose hai 1 mes
pai
achega
21243b3c9e
Modificáronse 2 ficheiros con 34 adicións e 9 borrados
  1. 29 7
      httplib.h
  2. 5 2
      test/test.cc

+ 29 - 7
httplib.h

@@ -1382,6 +1382,7 @@ public:
   virtual bool is_readable() const = 0;
   virtual bool wait_readable() const = 0;
   virtual bool wait_writable() const = 0;
+  virtual bool is_peer_alive() const { return wait_writable(); }
 
   virtual ssize_t read(char *ptr, size_t size) = 0;
   virtual ssize_t write(const char *ptr, size_t size) = 0;
@@ -5453,6 +5454,7 @@ public:
   bool is_readable() const override;
   bool wait_readable() const override;
   bool wait_writable() const override;
+  bool is_peer_alive() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -7106,10 +7108,10 @@ inline bool write_content_with_progress(Stream &strm,
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
 
   while (offset < end_offset && !is_shutting_down()) {
-    if (!strm.wait_writable()) {
+    if (!strm.wait_writable() || !strm.is_peer_alive()) {
       error = Error::Write;
       return false;
     } else if (!content_provider(offset, end_offset - offset, data_sink)) {
@@ -7121,6 +7123,11 @@ inline bool write_content_with_progress(Stream &strm,
     }
   }
 
+  if (offset < end_offset) { // exited due to is_shutting_down(), not completion
+    error = Error::Write;
+    return false;
+  }
+
   error = Error::Success;
   return true;
 }
@@ -7160,12 +7167,12 @@ write_content_without_length(Stream &strm,
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
 
   data_sink.done = [&](void) { data_available = false; };
 
   while (data_available && !is_shutting_down()) {
-    if (!strm.wait_writable()) {
+    if (!strm.wait_writable() || !strm.is_peer_alive()) {
       return false;
     } else if (!content_provider(offset, 0, data_sink)) {
       return false;
@@ -7173,7 +7180,8 @@ write_content_without_length(Stream &strm,
       return false;
     }
   }
-  return true;
+  return !data_available; // true only if done() was called, false if shutting
+                          // down
 }
 
 template <typename T, typename U>
@@ -7209,7 +7217,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
     return ok;
   };
 
-  data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
+  data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
 
   auto done_with_trailer = [&](const Headers *trailer) {
     if (!ok) { return; }
@@ -7259,7 +7267,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
   };
 
   while (data_available && !is_shutting_down()) {
-    if (!strm.wait_writable()) {
+    if (!strm.wait_writable() || !strm.is_peer_alive()) {
       error = Error::Write;
       return false;
     } else if (!content_provider(offset, 0, data_sink)) {
@@ -7271,6 +7279,11 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
     }
   }
 
+  if (data_available) { // exited due to is_shutting_down(), not done()
+    error = Error::Write;
+    return false;
+  }
+
   error = Error::Success;
   return true;
 }
@@ -8439,6 +8452,7 @@ public:
   bool is_readable() const override;
   bool wait_readable() const override;
   bool wait_writable() const override;
+  bool is_peer_alive() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -9865,6 +9879,10 @@ inline bool SocketStream::wait_writable() const {
   return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0;
 }
 
+inline bool SocketStream::is_peer_alive() const {
+  return detail::is_socket_alive(sock_);
+}
+
 inline ssize_t SocketStream::read(char *ptr, size_t size) {
 #ifdef _WIN32
   size =
@@ -10196,6 +10214,10 @@ inline bool SSLSocketStream::wait_writable() const {
          !tls::is_peer_closed(session_, sock_);
 }
 
+inline bool SSLSocketStream::is_peer_alive() const {
+  return !tls::is_peer_closed(session_, sock_);
+}
+
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
   if (tls::pending(session_) > 0) {
     tls::TlsError err;

+ 5 - 2
test/test.cc

@@ -323,9 +323,10 @@ TEST(SocketStream, wait_writable_UNIX) {
   asSocketStream(fds[0], [&](Stream &s0) {
     EXPECT_EQ(s0.socket(), fds[0]);
     EXPECT_TRUE(s0.wait_writable());
+    EXPECT_TRUE(s0.is_peer_alive());
 
     EXPECT_EQ(0, close(fds[1]));
-    EXPECT_FALSE(s0.wait_writable());
+    EXPECT_FALSE(s0.is_peer_alive());
 
     return true;
   });
@@ -367,7 +368,9 @@ TEST(SocketStream, wait_writable_INET) {
   };
   asSocketStream(disconnected_svr_sock, [&](Stream &ss) {
     EXPECT_EQ(ss.socket(), disconnected_svr_sock);
-    EXPECT_FALSE(ss.wait_writable());
+    // wait_writable() returns true because select_write() only checks if the
+    // send buffer has space. Peer disconnection is detected later by send().
+    EXPECT_TRUE(ss.wait_writable());
 
     return true;
   });