diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d9206..f83b12a6 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -312,9 +312,15 @@ async def handle_async_request(self, request: Request) -> Response: "server_hostname": self._remote_origin.host.decode("ascii"), "timeout": timeout, } - async with Trace("start_tls", logger, request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) - trace.return_value = stream + try: + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + except Exception: + # Close the underlying connection when TLS handshake fails to avoid + # zombie connections occupying the connection pool + await self._connection.aclose() + raise # Determine if we should be using HTTP/1.1 or HTTP/2 ssl_object = stream.get_extra_info("ssl_object") diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7..69310ea5 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -312,9 +312,15 @@ def handle_request(self, request: Request) -> Response: "server_hostname": self._remote_origin.host.decode("ascii"), "timeout": timeout, } - with Trace("start_tls", logger, request, kwargs) as trace: - stream = stream.start_tls(**kwargs) - trace.return_value = stream + try: + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + except Exception: + # Close the underlying connection when TLS handshake fails to avoid + # zombie connections occupying the connection pool + self._connection.close() + raise # Determine if we should be using HTTP/1.1 or HTTP/2 ssl_object = stream.get_extra_info("ssl_object") diff --git a/tests/_async/test_http_proxy.py b/tests/_async/test_http_proxy.py index 84a984b8..81919466 100644 --- a/tests/_async/test_http_proxy.py +++ b/tests/_async/test_http_proxy.py @@ -224,7 +224,7 @@ async def test_proxy_tunneling_with_403(): """ network_backend = AsyncMockBackend( [ - b"HTTP/1.1 403 Permission Denied\r\n" b"\r\n", + b"HTTP/1.1 403 Permission Denied\r\n\r\n", ] ) @@ -276,3 +276,45 @@ def test_proxy_headers(): assert proxy.headers == [ (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") ] + + +@pytest.mark.anyio +async def test_proxy_tunneling_tls_error(): + """ + Send an HTTPS request via a proxy, but the TLS handshake fails. + """ + + class BrokenTLSStream(AsyncMockStream): + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> AsyncNetworkStream: + raise OSError("TLS Failure") + + class BrokenTLSBackend(AsyncMockBackend): + async def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> AsyncNetworkStream: + return BrokenTLSStream(list(self._buffer)) + + network_backend = BrokenTLSBackend( + [ + b"HTTP/1.1 200 OK\r\n\r\n", + ] + ) + + async with AsyncConnectionPool( + proxy=Proxy("http://localhost:8080/"), + network_backend=network_backend, + ) as proxy: + with pytest.raises(OSError, match="TLS Failure"): + await proxy.request("GET", "https://example.com/") + + assert not proxy.connections diff --git a/tests/_sync/test_http_proxy.py b/tests/_sync/test_http_proxy.py index 966672dd..34bbc166 100644 --- a/tests/_sync/test_http_proxy.py +++ b/tests/_sync/test_http_proxy.py @@ -224,7 +224,7 @@ def test_proxy_tunneling_with_403(): """ network_backend = MockBackend( [ - b"HTTP/1.1 403 Permission Denied\r\n" b"\r\n", + b"HTTP/1.1 403 Permission Denied\r\n\r\n", ] ) @@ -276,3 +276,45 @@ def test_proxy_headers(): assert proxy.headers == [ (b"Proxy-Authorization", b"Basic dXNlcm5hbWU6cGFzc3dvcmQ=") ] + + + +def test_proxy_tunneling_tls_error(): + """ + Send an HTTPS request via a proxy, but the TLS handshake fails. + """ + + class BrokenTLSStream(MockStream): + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: typing.Optional[str] = None, + timeout: typing.Optional[float] = None, + ) -> NetworkStream: + raise OSError("TLS Failure") + + class BrokenTLSBackend(MockBackend): + def connect_tcp( + self, + host: str, + port: int, + timeout: typing.Optional[float] = None, + local_address: typing.Optional[str] = None, + socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None, + ) -> NetworkStream: + return BrokenTLSStream(list(self._buffer)) + + network_backend = BrokenTLSBackend( + [ + b"HTTP/1.1 200 OK\r\n\r\n", + ] + ) + + with ConnectionPool( + proxy=Proxy("http://localhost:8080/"), + network_backend=network_backend, + ) as proxy: + with pytest.raises(OSError, match="TLS Failure"): + proxy.request("GET", "https://example.com/") + + assert not proxy.connections