fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
This subpackage hold anything that is very relevant
|
||||
to the HTTP ecosystem but not per-say Niquests core logic.
|
||||
"""
|
||||
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
from pyodide.ffi import run_sync # type: ignore[import]
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ..._constant import DEFAULT_RETRIES
|
||||
from ...adapters import BaseAdapter
|
||||
from ...exceptions import ConnectionError, ConnectTimeout
|
||||
from ...models import PreparedRequest, Response
|
||||
from ...packages.urllib3.exceptions import MaxRetryError
|
||||
from ...packages.urllib3.response import BytesQueueBuffer
|
||||
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
|
||||
from ...packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ...packages.urllib3.util.retry import Retry
|
||||
from ...structures import CaseInsensitiveDict
|
||||
from ...utils import get_encoding_from_headers
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _PyodideRawIO:
|
||||
"""File-like wrapper around a Pyodide Fetch response with true streaming via JSPI.
|
||||
|
||||
When constructed with a JS Response object, reads chunks incrementally from the
|
||||
JavaScript ReadableStream using ``run_sync(reader.read())`` per chunk.
|
||||
When constructed with preloaded content (non-streaming), serves from a memory buffer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
js_response: typing.Any = None,
|
||||
preloaded_content: bytes | None = None,
|
||||
) -> None:
|
||||
self._js_response = js_response
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = preloaded_content is not None or js_response is None
|
||||
self._reader: typing.Any = None
|
||||
self.headers: dict[str, str] = {}
|
||||
self.extension: typing.Any = None
|
||||
|
||||
if preloaded_content is not None:
|
||||
self._buffer.put(preloaded_content)
|
||||
|
||||
def _ensure_reader(self) -> None:
|
||||
"""Initialize the ReadableStream reader if not already done."""
|
||||
if self._reader is None and self._js_response is not None:
|
||||
try:
|
||||
body = self._js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_next_chunk(self) -> bytes | None:
|
||||
"""Read the next chunk from the JS ReadableStream, blocking via JSPI."""
|
||||
self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = run_sync(self._reader.read())
|
||||
|
||||
if result.done:
|
||||
return None
|
||||
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py())
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def read(
|
||||
self,
|
||||
amt: int | None = None,
|
||||
decode_content: bool = True,
|
||||
) -> bytes:
|
||||
if self._closed:
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
if self._finished:
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
if amt is None or amt < 0:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read everything remaining
|
||||
while True:
|
||||
chunk = self._get_next_chunk()
|
||||
if chunk is None:
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
# Read until we have enough bytes or stream ends
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = self._get_next_chunk()
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
|
||||
"""Iterate over chunks of the response."""
|
||||
while True:
|
||||
chunk = self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
if self._reader is not None:
|
||||
try:
|
||||
run_sync(self._reader.cancel())
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._js_response = None
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
chunk = self.read(8192)
|
||||
if not chunk:
|
||||
raise StopIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class PyodideAdapter(BaseAdapter):
|
||||
"""Synchronous adapter for making HTTP requests in Pyodide using JSPI + pyfetch."""
|
||||
|
||||
def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<PyodideAdapter WASM/>"
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest using Pyodide's pyfetch (synchronous via JSPI)."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
start = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
|
||||
retries.sleep()
|
||||
continue
|
||||
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
retries.sleep(base_response)
|
||||
continue
|
||||
|
||||
response.elapsed = timedelta(seconds=time.time() - start)
|
||||
return response
|
||||
|
||||
def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> Response:
|
||||
"""Perform the actual request using pyfetch made synchronous via JSPI."""
|
||||
url = request.url or ""
|
||||
scheme = url.split("://")[0].lower() if "://" in url else ""
|
||||
|
||||
# WebSocket: delegate to browser native WebSocket API
|
||||
if scheme in ("ws", "wss"):
|
||||
return self._do_send_ws(request, url)
|
||||
|
||||
# SSE: delegate to pyfetch streaming + manual SSE parsing
|
||||
if scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request, url, scheme)
|
||||
|
||||
# Prepare headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
|
||||
headers_dict[key] = value
|
||||
|
||||
# Prepare body
|
||||
body = request.body
|
||||
if body is not None:
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
|
||||
chunks: list[bytes] = []
|
||||
for chunk in body:
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
elif isinstance(chunk, bytes):
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
|
||||
# Build fetch options
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": request.method or "GET",
|
||||
"headers": headers_dict,
|
||||
}
|
||||
|
||||
if body:
|
||||
fetch_options["body"] = body
|
||||
|
||||
# Use AbortSignal.timeout() for timeout — the browser-native mechanism.
|
||||
# run_sync cannot interrupt a JS Promise.
|
||||
signal = None
|
||||
|
||||
if timeout is not None:
|
||||
from js import AbortSignal # type: ignore[import]
|
||||
|
||||
signal = AbortSignal.timeout(int(timeout * 1000))
|
||||
|
||||
try:
|
||||
js_response = run_sync(pyfetch(request.url, signal=signal, **fetch_options))
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
|
||||
raise ConnectTimeout(f"Connection to {request.url} timed out")
|
||||
raise ConnectionError(f"Failed to fetch {request.url}: {e}")
|
||||
|
||||
# Parse response headers
|
||||
response_headers: dict[str, str] = {}
|
||||
try:
|
||||
if hasattr(js_response, "headers"):
|
||||
js_headers = js_response.headers
|
||||
if hasattr(js_headers, "items"):
|
||||
for key, value in js_headers.items():
|
||||
response_headers[key] = value
|
||||
elif hasattr(js_headers, "entries"):
|
||||
for entry in js_headers.entries():
|
||||
response_headers[entry[0]] = entry[1]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build response object
|
||||
response = Response()
|
||||
response.status_code = js_response.status
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = js_response.url or request.url
|
||||
response.encoding = get_encoding_from_headers(response_headers)
|
||||
|
||||
try:
|
||||
response.reason = js_response.status_text or ""
|
||||
except Exception:
|
||||
response.reason = ""
|
||||
|
||||
if stream:
|
||||
# Streaming: pass the underlying JS Response to the raw IO
|
||||
# so it can read chunks incrementally via run_sync(reader.read())
|
||||
raw_io = _PyodideRawIO(js_response=js_response.js_response)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# Non-streaming: read full body upfront
|
||||
try:
|
||||
response_body: bytes = run_sync(js_response.bytes())
|
||||
except Exception:
|
||||
response_body = b""
|
||||
|
||||
raw_io = _PyodideRawIO(preloaded_content=response_body)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = response_body
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_ws(self, request: PreparedRequest, url: str) -> Response:
|
||||
"""Handle WebSocket connections via browser native WebSocket API."""
|
||||
from ._ws import PyodideWebSocketExtension
|
||||
|
||||
try:
|
||||
ext = PyodideWebSocketExtension(url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "Switching Protocols"
|
||||
|
||||
raw_io = _PyodideRawIO()
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = b""
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> Response:
|
||||
"""Handle SSE connections via pyfetch streaming + manual parsing."""
|
||||
from ._sse import PyodideSSEExtension
|
||||
|
||||
http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
|
||||
|
||||
# Pass through user-provided headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection"):
|
||||
headers_dict[key] = value
|
||||
|
||||
try:
|
||||
ext = PyodideSSEExtension(http_url, headers=headers_dict)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"SSE connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "OK"
|
||||
|
||||
raw_io = _PyodideRawIO()
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
|
||||
return response
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ("PyodideAdapter",)
|
||||
@@ -0,0 +1,451 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ...._constant import DEFAULT_RETRIES
|
||||
from ....adapters import AsyncBaseAdapter
|
||||
from ....exceptions import ConnectionError, ConnectTimeout, ReadTimeout
|
||||
from ....models import AsyncResponse, PreparedRequest, Response
|
||||
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
|
||||
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
|
||||
from ....packages.urllib3.exceptions import MaxRetryError
|
||||
from ....packages.urllib3.response import BytesQueueBuffer
|
||||
from ....packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ....packages.urllib3.util.retry import Retry
|
||||
from ....structures import CaseInsensitiveDict
|
||||
from ....utils import _swap_context, get_encoding_from_headers
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ....typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _AsyncPyodideRawIO:
|
||||
"""
|
||||
Async file-like wrapper around Pyodide Fetch response for true streaming.
|
||||
|
||||
This class uses the JavaScript ReadableStream API through Pyodide to provide
|
||||
genuine streaming support without buffering the entire response in memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
js_response: typing.Any, # JavaScript Response object
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
self._js_response = js_response
|
||||
self._timeout = timeout
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = False
|
||||
self._reader: typing.Any = None # JavaScript ReadableStreamDefaultReader
|
||||
self.headers: dict[str, str] = {}
|
||||
self.extension: typing.Any = None
|
||||
|
||||
async def _ensure_reader(self) -> None:
|
||||
"""Initialize the stream reader if not already done."""
|
||||
if self._reader is None and self._js_response is not None:
|
||||
try:
|
||||
body = self._js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
|
||||
"""
|
||||
Read up to `amt` bytes from the response stream.
|
||||
|
||||
When `amt` is None, reads the entire remaining response.
|
||||
"""
|
||||
if self._closed:
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
if self._finished:
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
if amt is None or amt < 0:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read everything remaining
|
||||
async for chunk in self._stream_chunks():
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
# Read until we have enough bytes or stream ends
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = await self._get_next_chunk() # type: ignore[assignment]
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
async def _get_next_chunk(self) -> bytes | None:
|
||||
"""Read the next chunk from the JavaScript ReadableStream."""
|
||||
await self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if self._timeout is not None:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
result = await self._reader.read()
|
||||
else:
|
||||
result = await self._reader.read()
|
||||
|
||||
if result.done:
|
||||
return None
|
||||
|
||||
value = result.value
|
||||
if value is not None:
|
||||
# Convert Uint8Array to bytes
|
||||
return bytes(value.to_py())
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout("Read timed out while streaming Pyodide response")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _stream_chunks(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Async generator that yields chunks from the stream."""
|
||||
await self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
chunk = await self._get_next_chunk()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Return an async generator that yields chunks of `amt` bytes."""
|
||||
return self._async_stream(amt)
|
||||
|
||||
async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Internal async generator for streaming."""
|
||||
while True:
|
||||
chunk = await self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
self._closed = True
|
||||
if self._reader is not None:
|
||||
try:
|
||||
await self._reader.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._js_response = None
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
chunk = await self.read(8192)
|
||||
if not chunk:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class AsyncPyodideAdapter(AsyncBaseAdapter):
|
||||
"""Async adapter for making HTTP requests in Pyodide using the native pyfetch API."""
|
||||
|
||||
def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
"""
|
||||
Initialize the async Pyodide adapter.
|
||||
|
||||
:param max_retries: Maximum number of retries for requests.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<AsyncPyodideAdapter WASM/>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
|
||||
on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> AsyncResponse:
|
||||
"""Send a PreparedRequest using Pyodide's pyfetch (JavaScript Fetch API)."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
start = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
await retries.async_sleep()
|
||||
continue
|
||||
|
||||
# We rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
await retries.async_sleep(base_response)
|
||||
continue
|
||||
|
||||
response.elapsed = timedelta(seconds=time.time() - start)
|
||||
return response
|
||||
|
||||
async def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Perform the actual request using Pyodide's pyfetch."""
|
||||
url = request.url or ""
|
||||
scheme = url.split("://")[0].lower() if "://" in url else ""
|
||||
|
||||
# WebSocket: delegate to browser native WebSocket API
|
||||
if scheme in ("ws", "wss"):
|
||||
return await self._do_send_ws(request, url)
|
||||
|
||||
# SSE: delegate to pyfetch streaming + manual SSE parsing
|
||||
if scheme in ("sse", "psse"):
|
||||
return await self._do_send_sse(request, url, scheme)
|
||||
|
||||
# Prepare headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
# Skip headers that browsers don't allow to be set
|
||||
if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
|
||||
headers_dict[key] = value
|
||||
|
||||
# Prepare body
|
||||
body = request.body
|
||||
if body is not None:
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif hasattr(body, "__aiter__"):
|
||||
# Consume async iterable body
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in body: # type: ignore[union-attr]
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
else:
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
|
||||
# Consume sync iterable body
|
||||
chunks = []
|
||||
for chunk in body:
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
elif isinstance(chunk, bytes):
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
|
||||
# Build fetch options
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": request.method or "GET",
|
||||
"headers": headers_dict,
|
||||
}
|
||||
|
||||
if body:
|
||||
fetch_options["body"] = body
|
||||
|
||||
# Use AbortSignal.timeout() for timeout — the browser-native mechanism.
|
||||
# asyncio_timeout cannot interrupt a single JS Promise await.
|
||||
signal = None
|
||||
|
||||
if timeout is not None:
|
||||
from js import AbortSignal # type: ignore[import]
|
||||
|
||||
signal = AbortSignal.timeout(int(timeout * 1000))
|
||||
|
||||
try:
|
||||
js_response = await pyfetch(request.url, signal=signal, **fetch_options)
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
|
||||
raise ConnectTimeout(f"Connection to {request.url} timed out")
|
||||
raise ConnectionError(f"Failed to fetch {request.url}: {e}")
|
||||
|
||||
# Parse response headers
|
||||
response_headers: dict[str, str] = {}
|
||||
try:
|
||||
# Pyodide's FetchResponse has headers as a dict-like object
|
||||
if hasattr(js_response, "headers"):
|
||||
js_headers = js_response.headers
|
||||
if hasattr(js_headers, "items"):
|
||||
for key, value in js_headers.items():
|
||||
response_headers[key] = value
|
||||
elif hasattr(js_headers, "entries"):
|
||||
# JavaScript Headers.entries() returns an iterator
|
||||
for entry in js_headers.entries():
|
||||
response_headers[entry[0]] = entry[1]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build response object
|
||||
response = Response()
|
||||
response.status_code = js_response.status
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = js_response.url or request.url
|
||||
response.encoding = get_encoding_from_headers(response_headers)
|
||||
|
||||
# Try to get status text
|
||||
try:
|
||||
response.reason = js_response.status_text or ""
|
||||
except Exception:
|
||||
response.reason = ""
|
||||
|
||||
if stream:
|
||||
# For streaming: set up async raw IO using the JS Response object
|
||||
# This provides true streaming without buffering entire response
|
||||
raw_io = _AsyncPyodideRawIO(js_response.js_response, timeout)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# For non-streaming: get full response body
|
||||
try:
|
||||
if timeout is not None:
|
||||
async with asyncio_timeout(timeout):
|
||||
response_body = await js_response.bytes()
|
||||
else:
|
||||
response_body = await js_response.bytes()
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout(f"Read timed out for {request.url}")
|
||||
|
||||
response._content = response_body
|
||||
raw_io = _AsyncPyodideRawIO(None, timeout)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def _do_send_ws(self, request: PreparedRequest, url: str) -> AsyncResponse:
|
||||
"""Handle WebSocket connections via browser native WebSocket API."""
|
||||
from ._ws import AsyncPyodideWebSocketExtension
|
||||
|
||||
ext = AsyncPyodideWebSocketExtension()
|
||||
|
||||
try:
|
||||
await ext.start(url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "Switching Protocols"
|
||||
|
||||
raw_io = _AsyncPyodideRawIO(None)
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = b""
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> AsyncResponse:
|
||||
"""Handle SSE connections via pyfetch streaming + manual parsing."""
|
||||
from ._sse import AsyncPyodideSSEExtension
|
||||
|
||||
http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
|
||||
|
||||
# Pass through user-provided headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection"):
|
||||
headers_dict[key] = value
|
||||
|
||||
ext = AsyncPyodideSSEExtension()
|
||||
|
||||
try:
|
||||
await ext.start(http_url, headers=headers_dict)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"SSE connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "OK"
|
||||
|
||||
raw_io = _AsyncPyodideRawIO(None)
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ("AsyncPyodideAdapter",)
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class AsyncPyodideSSEExtension:
|
||||
"""Async SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.
|
||||
|
||||
Reads from a ReadableStream reader, buffers partial lines,
|
||||
and parses complete SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._reader: typing.Any = None
|
||||
|
||||
async def start(self, url: str, headers: dict[str, str] | None = None) -> None:
|
||||
"""Open the SSE stream via pyfetch."""
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": "GET",
|
||||
"headers": {
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
**(headers or {}),
|
||||
},
|
||||
}
|
||||
|
||||
js_response = await pyfetch(url, **fetch_options)
|
||||
|
||||
body = js_response.js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the ReadableStream."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await self._reader.read()
|
||||
if result.done:
|
||||
return None
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py()).decode("utf-8")
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
# Keep reading chunks until we have a complete event (double newline)
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = await self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._reader is not None:
|
||||
try:
|
||||
await self._reader.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from pyodide.ffi import create_proxy # type: ignore[import]
|
||||
|
||||
try:
|
||||
from js import WebSocket as JSWebSocket # type: ignore[import]
|
||||
except ImportError:
|
||||
JSWebSocket = None
|
||||
|
||||
|
||||
class AsyncPyodideWebSocketExtension:
|
||||
"""Async WebSocket extension for Pyodide using the browser's native WebSocket API.
|
||||
|
||||
Messages are queued from JS callbacks and dequeued by next_payload()
|
||||
which awaits until a message arrives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._queue: asyncio.Queue[str | bytes | None] = asyncio.Queue()
|
||||
self._proxies: list[typing.Any] = []
|
||||
self._ws: typing.Any = None
|
||||
|
||||
async def start(self, url: str) -> None:
|
||||
"""Open the WebSocket connection and wait until it's ready."""
|
||||
if JSWebSocket is None: # Defensive: depends on JS runtime
|
||||
raise OSError(
|
||||
"WebSocket is not available in this JavaScript runtime. "
|
||||
"Browser environment required (not supported in Node.js)."
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
open_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
self._ws = JSWebSocket.new(url)
|
||||
self._ws.binaryType = "arraybuffer"
|
||||
|
||||
# JS→Python callbacks: invoked by the browser event loop, not Python call
|
||||
# frames, so coverage cannot trace into them.
|
||||
def _onopen(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if not open_future.done():
|
||||
open_future.set_result(None)
|
||||
|
||||
def _onerror(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if not open_future.done():
|
||||
open_future.set_exception(ConnectionError("WebSocket connection failed"))
|
||||
|
||||
def _onmessage(event: typing.Any) -> None: # Defensive: JS callback
|
||||
data = event.data
|
||||
if isinstance(data, str):
|
||||
self._queue.put_nowait(data)
|
||||
else:
|
||||
# ArrayBuffer → bytes via Pyodide
|
||||
self._queue.put_nowait(bytes(data.to_py()))
|
||||
|
||||
def _onclose(event: typing.Any) -> None: # Defensive: JS callback
|
||||
self._queue.put_nowait(None)
|
||||
|
||||
for name, fn in [
|
||||
("onopen", _onopen),
|
||||
("onerror", _onerror),
|
||||
("onmessage", _onmessage),
|
||||
("onclose", _onclose),
|
||||
]:
|
||||
proxy = create_proxy(fn)
|
||||
self._proxies.append(proxy)
|
||||
setattr(self._ws, name, proxy)
|
||||
|
||||
await open_future
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Await the next message from the WebSocket.
|
||||
Returns None when the remote end closes the connection."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
msg = await self._queue.get()
|
||||
|
||||
if msg is None: # Defensive: server-initiated close sentinel
|
||||
self._closed = True
|
||||
|
||||
return msg
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message over the WebSocket."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
from js import Uint8Array # type: ignore[import]
|
||||
|
||||
self._ws.send(Uint8Array.new(buf))
|
||||
else:
|
||||
self._ws.send(buf)
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""No-op — browser WebSocket handles ping/pong at protocol level."""
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket and clean up proxies."""
|
||||
if self._closed: # Defensive: idempotent close
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
|
||||
for proxy in self._proxies:
|
||||
try:
|
||||
proxy.destroy()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
self._proxies.clear()
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from pyodide.ffi import run_sync # type: ignore[import]
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class PyodideSSEExtension:
|
||||
"""SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.
|
||||
|
||||
Synchronous via JSPI (run_sync). Reads from a ReadableStream reader,
|
||||
buffers partial lines, and parses complete SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, headers: dict[str, str] | None = None) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._reader: typing.Any = None
|
||||
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": "GET",
|
||||
"headers": {
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
**(headers or {}),
|
||||
},
|
||||
}
|
||||
|
||||
js_response = run_sync(pyfetch(url, **fetch_options))
|
||||
|
||||
body = js_response.js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the ReadableStream, blocking via JSPI."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = run_sync(self._reader.read())
|
||||
if result.done:
|
||||
return None
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py()).decode("utf-8")
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
# Keep reading chunks until we have a complete event (double newline)
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._reader is not None:
|
||||
try:
|
||||
run_sync(self._reader.cancel())
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from js import Promise # type: ignore[import]
|
||||
from pyodide.ffi import create_proxy, run_sync # type: ignore[import]
|
||||
|
||||
try:
|
||||
from js import WebSocket as JSWebSocket # type: ignore[import]
|
||||
except ImportError:
|
||||
JSWebSocket = None
|
||||
|
||||
|
||||
class PyodideWebSocketExtension:
|
||||
"""WebSocket extension for Pyodide using the browser's native WebSocket API.
|
||||
|
||||
Synchronous via JSPI (run_sync). Uses JS Promises for signaling instead of
|
||||
asyncio primitives, because run_sync bypasses the asyncio event loop.
|
||||
Messages from JS callbacks are buffered in a list and delivered via
|
||||
next_payload() which blocks via run_sync on a JS Promise.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str) -> None:
|
||||
if JSWebSocket is None: # Defensive: depends on JS runtime
|
||||
raise OSError(
|
||||
"WebSocket is not available in this JavaScript runtime. "
|
||||
"Browser environment required (not supported in Node.js)."
|
||||
)
|
||||
|
||||
self._closed = False
|
||||
self._pending: list[str | bytes | None] = []
|
||||
self._waiting_resolve: typing.Any = None
|
||||
self._last_msg: str | bytes | None = None
|
||||
self._proxies: list[typing.Any] = []
|
||||
|
||||
# Create browser WebSocket
|
||||
self._ws = JSWebSocket.new(url)
|
||||
self._ws.binaryType = "arraybuffer"
|
||||
|
||||
# Open/error signaling via a JS Promise (not asyncio.Future,
|
||||
# because run_sync/JSPI does not drive the asyncio event loop).
|
||||
_open_state: dict[str, typing.Any] = {
|
||||
"resolve": None,
|
||||
"reject": None,
|
||||
}
|
||||
|
||||
def _open_executor(resolve: typing.Any, reject: typing.Any) -> None:
|
||||
_open_state["resolve"] = resolve
|
||||
_open_state["reject"] = reject
|
||||
|
||||
exec_proxy = create_proxy(_open_executor)
|
||||
open_promise = Promise.new(exec_proxy)
|
||||
self._proxies.append(exec_proxy)
|
||||
|
||||
# JS→Python callbacks: invoked by the browser event loop, not Python call
|
||||
# frames, so coverage cannot trace into them.
|
||||
def _onopen(event: typing.Any) -> None: # Defensive: JS callback
|
||||
r = _open_state["resolve"]
|
||||
if r is not None:
|
||||
_open_state["resolve"] = None
|
||||
_open_state["reject"] = None
|
||||
r()
|
||||
|
||||
def _onerror(event: typing.Any) -> None: # Defensive: JS callback
|
||||
r = _open_state["reject"]
|
||||
if r is not None:
|
||||
_open_state["resolve"] = None
|
||||
_open_state["reject"] = None
|
||||
r("WebSocket connection failed")
|
||||
|
||||
def _onmessage(event: typing.Any) -> None: # Defensive: JS callback
|
||||
data = event.data
|
||||
if isinstance(data, str):
|
||||
msg: str | bytes = data
|
||||
else:
|
||||
# ArrayBuffer → bytes via Pyodide
|
||||
msg = bytes(data.to_py())
|
||||
|
||||
if self._waiting_resolve is not None:
|
||||
self._last_msg = msg
|
||||
r = self._waiting_resolve
|
||||
self._waiting_resolve = None
|
||||
r()
|
||||
else:
|
||||
self._pending.append(msg)
|
||||
|
||||
def _onclose(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if self._waiting_resolve is not None:
|
||||
self._last_msg = None
|
||||
r = self._waiting_resolve
|
||||
self._waiting_resolve = None
|
||||
r()
|
||||
else:
|
||||
self._pending.append(None)
|
||||
|
||||
# Create proxies so JS can call these Python functions
|
||||
for name, fn in [
|
||||
("onopen", _onopen),
|
||||
("onerror", _onerror),
|
||||
("onmessage", _onmessage),
|
||||
("onclose", _onclose),
|
||||
]:
|
||||
proxy = create_proxy(fn)
|
||||
self._proxies.append(proxy)
|
||||
setattr(self._ws, name, proxy)
|
||||
|
||||
# Block until the connection is open (or error rejects the promise)
|
||||
run_sync(open_promise)
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Block (via JSPI) until the next message arrives.
|
||||
Returns None when the remote end closes the connection."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
# Drain from buffer first
|
||||
if self._pending:
|
||||
msg = self._pending.pop(0)
|
||||
if msg is None: # Defensive: serverinitiated close sentinel
|
||||
self._closed = True
|
||||
return msg
|
||||
|
||||
# Wait for the next message via a JS Promise
|
||||
def _executor(resolve: typing.Any, reject: typing.Any) -> None:
|
||||
self._waiting_resolve = resolve
|
||||
|
||||
exec_proxy = create_proxy(_executor)
|
||||
promise = Promise.new(exec_proxy)
|
||||
run_sync(promise)
|
||||
exec_proxy.destroy()
|
||||
|
||||
msg = self._last_msg
|
||||
if msg is None: # Defensive: server-initiated close sentinel
|
||||
self._closed = True
|
||||
return msg
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message over the WebSocket."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
from js import Uint8Array # type: ignore[import]
|
||||
|
||||
self._ws.send(Uint8Array.new(buf))
|
||||
else:
|
||||
self._ws.send(buf)
|
||||
|
||||
def ping(self) -> None:
|
||||
"""No-op — browser WebSocket handles ping/pong at protocol level."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the WebSocket and clean up proxies."""
|
||||
if self._closed: # Defensive: idempotent close
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
|
||||
for proxy in self._proxies:
|
||||
try:
|
||||
proxy.destroy()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
self._proxies.clear()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RevocationStrategy(Enum):
|
||||
PREFER_OCSP = 0
|
||||
PREFER_CRL = 1
|
||||
CHECK_ALL = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationConfiguration:
|
||||
strategy: RevocationStrategy | None = RevocationStrategy.PREFER_OCSP
|
||||
strict_mode: bool = False
|
||||
|
||||
|
||||
DEFAULT_STRATEGY: RevocationConfiguration = RevocationConfiguration()
|
||||
@@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
CertificateRevocationList,
|
||||
)
|
||||
|
||||
from ....exceptions import RequestException, SSLError
|
||||
from ....models import PreparedRequest
|
||||
from ....packages.urllib3 import ConnectionInfo
|
||||
from ....packages.urllib3.contrib.resolver import BaseResolver
|
||||
from ....packages.urllib3.exceptions import SecurityWarning
|
||||
from ....typing import ProxyType
|
||||
from .._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
|
||||
|
||||
|
||||
class InMemoryRevocationList:
|
||||
def __init__(self, max_size: int = 256):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, CertificateRevocationList] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._crl_endpoints: dict[str, str] = {}
|
||||
self._failure_count: int = 0
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks: dict[str, threading.RLock] = {}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
with self._access_lock:
|
||||
if fingerprint not in self._second_level_locks:
|
||||
self._second_level_locks[fingerprint] = threading.RLock()
|
||||
|
||||
lock = self._second_level_locks[fingerprint]
|
||||
|
||||
lock.acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
with self._access_lock:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
"_crl_endpoints": self._crl_endpoints,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks = {}
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
self._crl_endpoints = state["_crl_endpoints"]
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = CertificateRevocationList.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._access_lock:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
with self._access_lock:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
|
||||
with self._access_lock:
|
||||
if crl_distribution_point not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[crl_distribution_point]
|
||||
|
||||
if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
|
||||
del self._store[crl_distribution_point]
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
|
||||
fingerprint = _str_fingerprint_of(leaf)
|
||||
|
||||
if fingerprint in self._crl_endpoints:
|
||||
return self._crl_endpoints[fingerprint]
|
||||
|
||||
return None
|
||||
|
||||
def save(
|
||||
self,
|
||||
leaf: Certificate,
|
||||
issuer: Certificate,
|
||||
crl: CertificateRevocationList,
|
||||
crl_distribution_point: str,
|
||||
) -> None:
|
||||
with self._access_lock:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update_at > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(leaf)
|
||||
|
||||
self._store[crl_distribution_point] = crl
|
||||
self._crl_endpoints[peer_fingerprint] = crl_distribution_point
|
||||
self._issuers_map[peer_fingerprint] = issuer
|
||||
self._failure_count = 0
|
||||
|
||||
|
||||
def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: BaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationList | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# no CRL distribution point available.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationList()
|
||||
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
|
||||
|
||||
with cache.lock_for(peer_certificate):
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from ....sessions import Session
|
||||
|
||||
with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = session.get(hint_ca_issuers[0])
|
||||
except RequestException:
|
||||
pass
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
crl_http_response = session.get(
|
||||
crl_distribution_point,
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
|
||||
if crl_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
crl = CertificateRevocationList(crl_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not crl.authenticate_for(issuer_certificate.public_bytes()):
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the CRL response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: The X509 CRL is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
|
||||
|
||||
revocation_status = crl.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"CRL endpoint: {str(crl_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,405 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
CertificateRevocationList,
|
||||
)
|
||||
|
||||
from .....exceptions import RequestException, SSLError
|
||||
from .....models import PreparedRequest
|
||||
from .....packages.urllib3 import ConnectionInfo
|
||||
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
|
||||
from .....packages.urllib3.exceptions import SecurityWarning
|
||||
from .....typing import ProxyType
|
||||
from .....utils import is_cancelled_error_root_cause
|
||||
from ..._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
|
||||
|
||||
|
||||
class InMemoryRevocationList:
|
||||
def __init__(self, max_size: int = 256):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, CertificateRevocationList] = {}
|
||||
self._semaphores: dict[str, asyncio.Semaphore] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._crl_endpoints: dict[str, str] = {}
|
||||
self._failure_count: int = 0
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
"_crl_endpoints": self._crl_endpoints,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
self._crl_endpoints = state["_crl_endpoints"]
|
||||
self._semaphores = {}
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = CertificateRevocationList.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._semaphores:
|
||||
self._semaphores[fingerprint] = asyncio.Semaphore()
|
||||
|
||||
await self._semaphores[fingerprint].acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._semaphores[fingerprint].release()
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
|
||||
if crl_distribution_point not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[crl_distribution_point]
|
||||
|
||||
if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
|
||||
del self._store[crl_distribution_point]
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
|
||||
fingerprint = _str_fingerprint_of(leaf)
|
||||
|
||||
if fingerprint in self._crl_endpoints:
|
||||
return self._crl_endpoints[fingerprint]
|
||||
|
||||
return None
|
||||
|
||||
def save(
|
||||
self,
|
||||
leaf: Certificate,
|
||||
issuer: Certificate,
|
||||
crl: CertificateRevocationList,
|
||||
crl_distribution_point: str,
|
||||
) -> None:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update_at > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(leaf)
|
||||
|
||||
self._store[crl_distribution_point] = crl
|
||||
self._crl_endpoints[peer_fingerprint] = crl_distribution_point
|
||||
self._issuers_map[peer_fingerprint] = issuer
|
||||
self._failure_count = 0
|
||||
|
||||
|
||||
async def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: AsyncBaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationList | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationList()
|
||||
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
|
||||
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate is not None:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
async with cache.lock(peer_certificate):
|
||||
# why are we doing this twice?
|
||||
# because using Semaphore to prevent concurrent
|
||||
# revocation check have a heavy toll!
|
||||
# todo: respect DRY
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate is not None:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from .....async_session import AsyncSession
|
||||
|
||||
async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = await session.get(hint_ca_issuers[0])
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
crl_http_response = await session.get(
|
||||
crl_distribution_point,
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
|
||||
if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
|
||||
if crl_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
crl = CertificateRevocationList(crl_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not crl.authenticate_for(issuer_certificate.public_bytes()):
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the CRL response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid "
|
||||
"certificate via CRL. You are seeing this warning due to enabling strict "
|
||||
"mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
|
||||
|
||||
revocation_status = crl.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"CRL endpoint: {str(crl_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from hashlib import sha256
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
OCSPCertStatus,
|
||||
OCSPRequest,
|
||||
OCSPResponse,
|
||||
OCSPResponseStatus,
|
||||
ReasonFlags,
|
||||
)
|
||||
|
||||
from ....exceptions import RequestException, SSLError
|
||||
from ....models import PreparedRequest
|
||||
from ....packages.urllib3 import ConnectionInfo
|
||||
from ....packages.urllib3.contrib.resolver import BaseResolver
|
||||
from ....packages.urllib3.exceptions import SecurityWarning
|
||||
from ....typing import ProxyType
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _parse_x509_der_cached(der: bytes) -> Certificate:
|
||||
return Certificate(der)
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _fingerprint_raw_data(payload: bytes) -> str:
|
||||
return "".join([format(i, "02x") for i in sha256(payload).digest()])
|
||||
|
||||
|
||||
def _str_fingerprint_of(certificate: Certificate) -> str:
|
||||
return _fingerprint_raw_data(certificate.public_bytes())
|
||||
|
||||
|
||||
def readable_revocation_reason(flag: ReasonFlags | None) -> str | None:
|
||||
return str(flag).split(".")[-1].lower() if flag is not None else None
|
||||
|
||||
|
||||
class InMemoryRevocationStatus:
|
||||
def __init__(self, max_size: int = 2048):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, OCSPResponse] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._timings: list[datetime.datetime] = []
|
||||
self._failure_count: int = 0
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks: dict[str, threading.RLock] = {}
|
||||
self.hold: bool = False
|
||||
|
||||
@staticmethod
|
||||
def support_pickle() -> bool:
|
||||
"""This gives you a hint on whether you can cache it to restore later."""
|
||||
return hasattr(OCSPResponse, "serialize")
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
with self._access_lock:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
with self._access_lock:
|
||||
if fingerprint not in self._second_level_locks:
|
||||
self._second_level_locks[fingerprint] = threading.RLock()
|
||||
|
||||
lock = self._second_level_locks[fingerprint]
|
||||
|
||||
lock.acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks = {}
|
||||
self.hold = False
|
||||
self._timings = []
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = OCSPResponse.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._access_lock:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
with self._access_lock:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def rate(self):
|
||||
with self._access_lock:
|
||||
previous_dt: datetime.datetime | None = None
|
||||
delays: list[float] = []
|
||||
|
||||
for dt in self._timings:
|
||||
if previous_dt is None:
|
||||
previous_dt = dt
|
||||
continue
|
||||
delays.append((dt - previous_dt).total_seconds())
|
||||
previous_dt = dt
|
||||
|
||||
return sum(delays) / len(delays) if delays else 0.0
|
||||
|
||||
def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[fingerprint]
|
||||
|
||||
if cached_response.certificate_status == OCSPCertStatus.GOOD:
|
||||
if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
|
||||
del self._store[fingerprint]
|
||||
return None
|
||||
return cached_response
|
||||
|
||||
return cached_response
|
||||
|
||||
def save(
|
||||
self,
|
||||
peer_certificate: Certificate,
|
||||
issuer_certificate: Certificate,
|
||||
ocsp_response: OCSPResponse,
|
||||
) -> None:
|
||||
with self._access_lock:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
|
||||
tbd_key = k
|
||||
break
|
||||
|
||||
if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
del self._issuers_map[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
del self._issuers_map[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
self._store[peer_fingerprint] = ocsp_response
|
||||
self._issuers_map[peer_fingerprint] = issuer_certificate
|
||||
self._failure_count = 0
|
||||
|
||||
self._timings.append(datetime.datetime.now())
|
||||
|
||||
if len(self._timings) >= self._max_size:
|
||||
self._timings.pop(0)
|
||||
|
||||
|
||||
def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: BaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationStatus | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("OCSP", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationStatus()
|
||||
|
||||
# this feature, by default, is reserved for a reasonable usage.
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
mean_rate_sec = cache.rate()
|
||||
cache_count = len(cache)
|
||||
|
||||
if cache_count >= 10 and mean_rate_sec <= 1.0:
|
||||
cache.hold = True
|
||||
|
||||
if cache.hold:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
with cache.lock_for(peer_certificate):
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from ....sessions import Session
|
||||
|
||||
with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# try to do AIA fetching of intermediate certificate (issuer)
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = session.get(hint_ca_issuers[0])
|
||||
except RequestException:
|
||||
pass
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP generator failed to assemble the request."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_http_response = session.post(
|
||||
endpoints[randint(0, len(endpoints) - 1)],
|
||||
data=req.public_bytes(),
|
||||
headers={"Content-Type": "application/ocsp-request"},
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
# we want to monitor failures related to the responder.
|
||||
# we don't want to ruin the http experience in normal circumstances.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
|
||||
if ocsp_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_resp = OCSPResponse(ocsp_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
if verify_signature:
|
||||
try:
|
||||
if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()): # type: ignore[attr-defined]
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the OCSP response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError: # Defensive: unsupported signature case
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: The X509 OCSP response is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, ocsp_resp)
|
||||
|
||||
if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
|
||||
"certificate is valid or not. This error occurred because you enabled strict mode for "
|
||||
"the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {ocsp_resp.response_status}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {str(ocsp_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,492 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
OCSPCertStatus,
|
||||
OCSPRequest,
|
||||
OCSPResponse,
|
||||
OCSPResponseStatus,
|
||||
)
|
||||
|
||||
from .....exceptions import RequestException, SSLError
|
||||
from .....models import PreparedRequest
|
||||
from .....packages.urllib3 import ConnectionInfo
|
||||
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
|
||||
from .....packages.urllib3.exceptions import SecurityWarning
|
||||
from .....typing import ProxyType
|
||||
from .....utils import is_cancelled_error_root_cause
|
||||
from .. import (
|
||||
_parse_x509_der_cached,
|
||||
_str_fingerprint_of,
|
||||
readable_revocation_reason,
|
||||
)
|
||||
|
||||
|
||||
class InMemoryRevocationStatus:
|
||||
def __init__(self, max_size: int = 2048):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, OCSPResponse] = {}
|
||||
self._semaphores: dict[str, asyncio.Semaphore] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._timings: list[datetime.datetime] = []
|
||||
self._failure_count: int = 0
|
||||
self.hold: bool = False
|
||||
|
||||
@staticmethod
|
||||
def support_pickle() -> bool:
|
||||
"""This gives you a hint on whether you can cache it to restore later."""
|
||||
return hasattr(OCSPResponse, "serialize")
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self.hold = False
|
||||
self._timings = []
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
|
||||
self._store = {}
|
||||
self._semaphores = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = OCSPResponse.deserialize(v)
|
||||
self._semaphores[k] = asyncio.Semaphore()
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._semaphores:
|
||||
self._semaphores[fingerprint] = asyncio.Semaphore()
|
||||
|
||||
await self._semaphores[fingerprint].acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._semaphores[fingerprint].release()
|
||||
|
||||
def rate(self):
|
||||
previous_dt: datetime.datetime | None = None
|
||||
delays: list[float] = []
|
||||
|
||||
for dt in self._timings:
|
||||
if previous_dt is None:
|
||||
previous_dt = dt
|
||||
continue
|
||||
delays.append((dt - previous_dt).total_seconds())
|
||||
previous_dt = dt
|
||||
|
||||
return sum(delays) / len(delays) if delays else 0.0
|
||||
|
||||
def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[fingerprint]
|
||||
|
||||
if cached_response.certificate_status == OCSPCertStatus.GOOD:
|
||||
if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
|
||||
del self._store[fingerprint]
|
||||
return None
|
||||
return cached_response
|
||||
|
||||
return cached_response
|
||||
|
||||
def save(
|
||||
self,
|
||||
peer_certificate: Certificate,
|
||||
issuer_certificate: Certificate,
|
||||
ocsp_response: OCSPResponse,
|
||||
) -> None:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
|
||||
tbd_key = k
|
||||
break
|
||||
|
||||
if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
del self._issuers_map[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
del self._issuers_map[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
self._store[peer_fingerprint] = ocsp_response
|
||||
self._issuers_map[peer_fingerprint] = issuer_certificate
|
||||
self._failure_count = 0
|
||||
|
||||
self._timings.append(datetime.datetime.now())
|
||||
|
||||
if len(self._timings) >= self._max_size:
|
||||
self._timings.pop(0)
|
||||
|
||||
|
||||
async def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: AsyncBaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationStatus | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("OCSP", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationStatus()
|
||||
|
||||
peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
async with cache.lock(peer_certificate):
|
||||
# this feature, by default, is reserved for a reasonable usage.
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
mean_rate_sec = cache.rate()
|
||||
cache_count = len(cache)
|
||||
|
||||
if cache_count >= 10 and mean_rate_sec <= 1.0:
|
||||
cache.hold = True
|
||||
|
||||
if cache.hold:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
# why are we doing this twice?
|
||||
# because using Semaphore to prevent concurrent
|
||||
# revocation check have a heavy toll!
|
||||
# todo: respect DRY
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from .....async_session import AsyncSession
|
||||
|
||||
async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default (Python 3.7 to 3.9).
|
||||
# Unfortunately! But no worries, we can circumvent it! (Python 3.10+ is not concerned anymore)
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = await session.get(hint_ca_issuers[0])
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP generator failed to assemble the request."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_http_response = await session.post(
|
||||
endpoints[randint(0, len(endpoints) - 1)],
|
||||
data=req.public_bytes(),
|
||||
headers={"Content-Type": "application/ocsp-request"},
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
|
||||
if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
|
||||
if ocsp_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_resp = OCSPResponse(ocsp_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()): # type: ignore[attr-defined]
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the OCSP response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently "
|
||||
"valid certificate via OCSP. You are seeing this warning due to "
|
||||
"enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP response is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, ocsp_resp)
|
||||
|
||||
if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
|
||||
"certificate is valid or not. This error occurred because you enabled strict mode for "
|
||||
"the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {ocsp_resp.response_status}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {str(ocsp_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..._constant import DEFAULT_RETRIES
|
||||
from ...adapters import BaseAdapter
|
||||
from ...models import PreparedRequest, Response
|
||||
from ...packages.urllib3.exceptions import MaxRetryError
|
||||
from ...packages.urllib3.response import BytesQueueBuffer
|
||||
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
|
||||
from ...packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ...packages.urllib3.util.retry import Retry
|
||||
from ...structures import CaseInsensitiveDict
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType, WSGIApp
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class _WSGIRawIO:
|
||||
"""File-like wrapper around a WSGI response iterator for streaming."""
|
||||
|
||||
def __init__(self, generator: typing.Generator[bytes, None, None], headers: list[tuple[str, str]]) -> None:
|
||||
self._generator = generator
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self.headers = headers
|
||||
self.extension: typing.Any = None
|
||||
|
||||
def read(
|
||||
self,
|
||||
amt: int | None = None,
|
||||
decode_content: bool = True,
|
||||
) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read all remaining
|
||||
for chunk in self._generator:
|
||||
self._buffer.put(chunk)
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
# Read specific amount
|
||||
while len(self._buffer) < amt:
|
||||
try:
|
||||
self._buffer.put(next(self._generator))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
|
||||
"""Iterate over chunks of the response."""
|
||||
while True:
|
||||
chunk = self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
if hasattr(self._generator, "close"):
|
||||
self._generator.close()
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
chunk = self.read(8192)
|
||||
if not chunk:
|
||||
raise StopIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class WebServerGatewayInterface(BaseAdapter):
|
||||
"""Adapter for making requests to WSGI applications directly."""
|
||||
|
||||
def __init__(self, app: WSGIApp, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
"""
|
||||
Initialize the WSGI adapter.
|
||||
|
||||
:param app: A WSGI application callable.
|
||||
:param max_retries: Maximum number of retries for requests.
|
||||
"""
|
||||
super().__init__()
|
||||
self.app = app
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<WSGIAdapter />"
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest to the WSGI application."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self._do_send(request, stream)
|
||||
except Exception as err:
|
||||
try:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
except MaxRetryError:
|
||||
raise
|
||||
|
||||
retries.sleep()
|
||||
continue
|
||||
|
||||
# we rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
retries.sleep(base_response)
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest) -> Response:
|
||||
"""Handle SSE requests by wrapping the WSGI response with SSE parsing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._sse import WSGISSEExtension
|
||||
|
||||
# Convert scheme: sse:// -> https://, psse:// -> http://
|
||||
parsed = urlparse(request.url)
|
||||
if parsed.scheme == "sse":
|
||||
http_scheme = "https"
|
||||
else:
|
||||
http_scheme = "http"
|
||||
|
||||
original_url = request.url
|
||||
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
|
||||
|
||||
environ = self._create_environ(request)
|
||||
request.url = original_url
|
||||
|
||||
status_code = None
|
||||
response_headers: list[tuple[str, str]] = []
|
||||
|
||||
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
|
||||
nonlocal status_code, response_headers
|
||||
status_code = int(status.split(" ", 1)[0])
|
||||
response_headers = headers
|
||||
|
||||
result = self.app(environ, start_response)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
yield from result
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
|
||||
ext = WSGISSEExtension(generate())
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = original_url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
raw_io = _WSGIRawIO(iter([]), response_headers) # type: ignore[arg-type]
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
|
||||
return response
|
||||
|
||||
def _do_send(self, request: PreparedRequest, stream: bool) -> Response:
|
||||
"""Perform the actual WSGI request."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
raise NotImplementedError("WebSocket is not supported over WSGI")
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request)
|
||||
|
||||
environ = self._create_environ(request)
|
||||
|
||||
status_code = None
|
||||
response_headers: list[tuple[str, str]] = []
|
||||
|
||||
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
|
||||
nonlocal status_code, response_headers
|
||||
status_code = int(status.split(" ", 1)[0])
|
||||
response_headers = headers
|
||||
|
||||
result = self.app(environ, start_response)
|
||||
|
||||
response = Response()
|
||||
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
# Wrap the WSGI iterator for streaming
|
||||
def generate():
|
||||
try:
|
||||
yield from result
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
|
||||
response.raw = _WSGIRawIO(generate(), response_headers) # type: ignore
|
||||
|
||||
if stream:
|
||||
response._content = False # Indicate content not yet consumed
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# Consume all content immediately
|
||||
body_chunks: list[bytes] = []
|
||||
try:
|
||||
for chunk in result:
|
||||
body_chunks.append(chunk)
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
response._content = b"".join(body_chunks)
|
||||
|
||||
return response
|
||||
|
||||
def _create_environ(self, request: PreparedRequest) -> dict:
|
||||
"""Create a WSGI environ dict from a PreparedRequest."""
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
body = request.body or b""
|
||||
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (str, bytes, bytearray, tuple)):
|
||||
tmp = b""
|
||||
|
||||
for chunk in body:
|
||||
tmp += chunk # type: ignore[operator]
|
||||
|
||||
body = tmp
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": request.method,
|
||||
"SCRIPT_NAME": "",
|
||||
"PATH_INFO": unquote(parsed.path) or "/",
|
||||
"QUERY_STRING": parsed.query or "",
|
||||
"SERVER_NAME": parsed.hostname or "localhost",
|
||||
"SERVER_PORT": str(parsed.port or (443 if parsed.scheme == "https" else 80)),
|
||||
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": parsed.scheme or "http",
|
||||
"wsgi.input": BytesIO(body), # type: ignore[arg-type]
|
||||
"wsgi.errors": BytesIO(),
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
"CONTENT_LENGTH": str(len(body)), # type: ignore[arg-type]
|
||||
}
|
||||
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
key_upper = key.upper().replace("-", "_")
|
||||
if key_upper == "CONTENT_TYPE":
|
||||
environ["CONTENT_TYPE"] = value
|
||||
elif key_upper == "CONTENT_LENGTH":
|
||||
environ["CONTENT_LENGTH"] = value
|
||||
else:
|
||||
environ[f"HTTP_{key_upper}"] = value
|
||||
|
||||
return environ
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
@@ -0,0 +1,821 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import threading
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
from ...._constant import DEFAULT_RETRIES
|
||||
from ....adapters import AsyncBaseAdapter, BaseAdapter
|
||||
from ....exceptions import ConnectTimeout, ReadTimeout
|
||||
from ....models import AsyncResponse, PreparedRequest, Response
|
||||
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
|
||||
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
|
||||
from ....packages.urllib3.exceptions import MaxRetryError
|
||||
from ....packages.urllib3.response import BytesQueueBuffer
|
||||
from ....packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ....packages.urllib3.util.retry import Retry
|
||||
from ....structures import CaseInsensitiveDict
|
||||
from ....utils import _swap_context
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ....typing import ASGIApp, ASGIMessage, ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _ASGIRawIO:
|
||||
"""Async file-like wrapper around an ASGI response for true async streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
self._response_queue = response_queue
|
||||
self._response_complete = response_complete
|
||||
self._timeout = timeout
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = False
|
||||
self._task: asyncio.Task | None = None
|
||||
self.headers: dict | None = None
|
||||
self.extension: typing.Any = None
|
||||
|
||||
async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
|
||||
if self._closed or self._finished:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = await self._get_next_chunk() # type: ignore[assignment]
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
async def _get_next_chunk(self) -> bytes | None:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
return message.get("body", b"")
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
except asyncio.CancelledError:
|
||||
return None
|
||||
|
||||
async def _async_iter_chunks(self) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
async def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes]:
|
||||
return self._async_stream(amt)
|
||||
|
||||
async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
chunk = await self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
self._response_complete.set()
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
return self._async_iter_self()
|
||||
|
||||
async def _async_iter_self(self) -> typing.AsyncIterator[bytes]:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
yield chunk
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
chunk = await self.read(8192)
|
||||
if not chunk:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class AsyncServerGatewayInterface(AsyncBaseAdapter):
|
||||
"""Adapter for making requests to ASGI applications directly."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
lifespan_state: dict[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self._lifespan_state = lifespan_state
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Native/>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
|
||||
on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> AsyncResponse:
|
||||
"""Send a PreparedRequest to the ASGI application."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
try:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
except MaxRetryError:
|
||||
raise
|
||||
|
||||
await retries.async_sleep()
|
||||
continue
|
||||
|
||||
# we rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
await retries.async_sleep(base_response)
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
async def _do_send_ws(self, request: PreparedRequest) -> AsyncResponse:
|
||||
"""Handle WebSocket requests via the ASGI websocket protocol."""
|
||||
from ._ws import ASGIWebSocketExtension
|
||||
|
||||
scope = self._create_ws_scope(request)
|
||||
|
||||
ext = ASGIWebSocketExtension()
|
||||
await ext.start(self.app, scope)
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket"})
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event())
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send_sse(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Handle SSE requests via ASGI HTTP streaming with SSE parsing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._sse import ASGISSEExtension
|
||||
|
||||
# Convert scheme: sse:// -> https://, psse:// -> http://
|
||||
parsed = urlparse(request.url)
|
||||
if parsed.scheme == "sse":
|
||||
http_scheme = "https"
|
||||
else:
|
||||
http_scheme = "http"
|
||||
|
||||
# Build a modified request with http scheme for scope creation
|
||||
original_url = request.url
|
||||
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
|
||||
|
||||
scope = self._create_scope(request)
|
||||
request.url = original_url
|
||||
|
||||
body = request.body or b""
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
ext = ASGISSEExtension()
|
||||
start_message = await ext.start(self.app, scope, body) # type: ignore[arg-type]
|
||||
|
||||
status_code = start_message["status"]
|
||||
response_headers = start_message.get("headers", [])
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = original_url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event(), timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Perform the actual ASGI request."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return await self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return await self._do_send_sse(request, timeout)
|
||||
|
||||
scope = self._create_scope(request)
|
||||
|
||||
body = request.body or b""
|
||||
body_iter: typing.AsyncIterator[bytes] | typing.AsyncIterator[str] | None = None
|
||||
|
||||
# Check if body is an async iterable
|
||||
if hasattr(body, "__aiter__"):
|
||||
body_iter = body.__aiter__()
|
||||
body = b"" # Will be streamed
|
||||
elif isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
response_queue: asyncio.Queue[ASGIMessage | None] = asyncio.Queue()
|
||||
app_exception: Exception | None = None
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
if body_iter is not None:
|
||||
# Stream chunks from async iterable
|
||||
try:
|
||||
chunk = await body_iter.__anext__()
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopAsyncIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
else:
|
||||
# Single body chunk
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send_func(message: ASGIMessage) -> None:
|
||||
await response_queue.put(message)
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
nonlocal app_exception
|
||||
try:
|
||||
await self.app(scope, receive, send_func)
|
||||
except Exception as ex:
|
||||
app_exception = ex
|
||||
finally:
|
||||
await response_queue.put(None)
|
||||
|
||||
if stream:
|
||||
return await self._stream_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
else:
|
||||
return await self._buffered_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
|
||||
async def _stream_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
# Wait for http.response.start with timeout
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
break
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
raw_io = _ASGIRawIO(response_queue, response_complete, timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io._task = task
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ConnectTimeout("Timed out waiting for ASGI response headers")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
async def _buffered_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
body_chunks: list[bytes] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
elif message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
body_chunks.append(chunk)
|
||||
|
||||
await task
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ReadTimeout("Timed out reading ASGI response body")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
if self.raise_app_exceptions and get_exception() is not None:
|
||||
raise get_exception() # type: ignore
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response._content = b"".join(body_chunks)
|
||||
response.raw = _ASGIRawIO(response_queue, response_complete, timeout) # type: ignore
|
||||
response.raw.headers = headers_dict
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
def _create_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"scheme": "http",
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
# Include lifespan state if available (for frameworks like Starlette that use it)
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
def _create_ws_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scheme = "wss" if parsed.scheme == "wss" else "ws"
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"scheme": scheme,
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if scheme == "wss" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ThreadAsyncServerGatewayInterface(BaseAdapter):
|
||||
"""Synchronous adapter for ASGI applications using a background event loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
self._async_adapter: AsyncServerGatewayInterface | None = None
|
||||
self._loop: typing.Any = None # asyncio.AbstractEventLoop
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
self._lifespan_task: asyncio.Task | None = None
|
||||
self._lifespan_receive_queue: asyncio.Queue[ASGIMessage] | None = None
|
||||
self._lifespan_startup_complete = threading.Event()
|
||||
self._lifespan_startup_failed: Exception | None = None
|
||||
self._lifespan_state: dict[str, typing.Any] = {}
|
||||
self._startup_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Thread/>"
|
||||
|
||||
def _ensure_loop_running(self) -> None:
|
||||
"""Start the background event loop thread if not already running."""
|
||||
with self._startup_lock:
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
|
||||
import asyncio
|
||||
|
||||
def run_loop() -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._lifespan_receive_queue = asyncio.Queue()
|
||||
self._async_adapter = AsyncServerGatewayInterface(
|
||||
self.app,
|
||||
raise_app_exceptions=self.raise_app_exceptions,
|
||||
max_retries=self.max_retries,
|
||||
lifespan_state=self._lifespan_state,
|
||||
)
|
||||
# Start lifespan handler
|
||||
self._lifespan_task = self._loop.create_task(self._handle_lifespan())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait()
|
||||
self._lifespan_startup_complete.wait()
|
||||
|
||||
if self._lifespan_startup_failed is not None:
|
||||
raise self._lifespan_startup_failed
|
||||
|
||||
async def _handle_lifespan(self) -> None:
|
||||
"""Handle ASGI lifespan protocol."""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0"},
|
||||
"state": self._lifespan_state,
|
||||
}
|
||||
|
||||
startup_complete = asyncio.Event()
|
||||
shutdown_complete = asyncio.Event()
|
||||
startup_failed: list[Exception] = []
|
||||
|
||||
# Keep local reference to avoid race condition during shutdown
|
||||
receive_queue = self._lifespan_receive_queue
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
return await receive_queue.get() # type: ignore[union-attr]
|
||||
|
||||
async def send(message: ASGIMessage) -> None:
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.startup.failed":
|
||||
startup_failed.append(RuntimeError(message.get("message", "Lifespan startup failed")))
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
shutdown_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.failed":
|
||||
shutdown_complete.set()
|
||||
|
||||
async def run_lifespan() -> None:
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
if not startup_complete.is_set():
|
||||
startup_failed.append(e)
|
||||
startup_complete.set()
|
||||
|
||||
lifespan_task = asyncio.create_task(run_lifespan())
|
||||
|
||||
# Send startup event
|
||||
await receive_queue.put({"type": "lifespan.startup"}) # type: ignore[union-attr]
|
||||
await startup_complete.wait()
|
||||
|
||||
if startup_failed:
|
||||
self._lifespan_startup_failed = startup_failed[0]
|
||||
self._lifespan_startup_complete.set()
|
||||
|
||||
# Wait for shutdown signal (loop.stop() will cancel this)
|
||||
try:
|
||||
await asyncio.Future() # Wait forever until canceled
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
pass
|
||||
|
||||
# Send shutdown event - must happen before loop stops
|
||||
if receive_queue is not None:
|
||||
try:
|
||||
await receive_queue.put({"type": "lifespan.shutdown"})
|
||||
await asyncio.wait_for(shutdown_complete.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, RuntimeError):
|
||||
pass
|
||||
lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await lifespan_task
|
||||
|
||||
def _do_send_ws(self, request: PreparedRequest) -> Response:
|
||||
"""Handle WebSocket requests synchronously via the background loop."""
|
||||
from .._ws import ThreadASGIWebSocketExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_ws() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_ws(request) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async WS extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGIWebSocketExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_ws()))
|
||||
return future.result()
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest, timeout: int | float | None = None) -> Response:
|
||||
"""Handle SSE requests synchronously via the background loop."""
|
||||
from .._sse import ThreadASGISSEExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_sse() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_sse(request, timeout) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async SSE extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGISSEExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_sse()))
|
||||
return future.result()
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest to the ASGI application synchronously."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request, timeout)
|
||||
|
||||
if stream:
|
||||
raise ValueError(
|
||||
"ThreadAsyncServerGatewayInterface does not support streaming responses. "
|
||||
"Use stream=False or migrate to pure async/await implementation."
|
||||
)
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_send() -> None:
|
||||
try:
|
||||
result = await self._async_adapter.send( # type: ignore[union-attr]
|
||||
request,
|
||||
stream=False,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
)
|
||||
_swap_context(result)
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_send()))
|
||||
|
||||
return future.result()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
if self._loop is not None and self._lifespan_task is not None:
|
||||
# Signal shutdown and wait for it to complete
|
||||
shutdown_done = threading.Event()
|
||||
|
||||
async def do_shutdown() -> None:
|
||||
if self._lifespan_task is not None:
|
||||
self._lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._lifespan_task
|
||||
shutdown_done.set()
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(do_shutdown()))
|
||||
shutdown_done.wait(timeout=6.0)
|
||||
|
||||
if self._loop is not None:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
self._thread = None
|
||||
# Clear resources only after thread has stopped
|
||||
self._loop = None
|
||||
self._async_adapter = None
|
||||
self._lifespan_task = None
|
||||
self._lifespan_receive_queue = None
|
||||
self._started.clear()
|
||||
self._lifespan_startup_complete.clear()
|
||||
self._lifespan_startup_failed = None
|
||||
|
||||
|
||||
__all__ = ("AsyncServerGatewayInterface", "ThreadAsyncServerGatewayInterface")
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class ASGISSEExtension:
|
||||
"""Async SSE extension for ASGI applications.
|
||||
|
||||
Runs the ASGI app as a normal HTTP streaming request and parses
|
||||
SSE events from the response body chunks.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._response_queue: asyncio.Queue[dict[str, typing.Any] | None] | None = None
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(
|
||||
self,
|
||||
app: typing.Any,
|
||||
scope: dict[str, typing.Any],
|
||||
body: bytes = b"",
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Start the ASGI app and wait for http.response.start.
|
||||
|
||||
Returns the response start message (with status and headers).
|
||||
"""
|
||||
self._response_queue = asyncio.Queue()
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._response_queue.put(message) # type: ignore[union-attr]
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
finally:
|
||||
await self._response_queue.put(None) # type: ignore[union-attr]
|
||||
|
||||
self._task = asyncio.create_task(run_app())
|
||||
|
||||
# Wait for http.response.start
|
||||
while True:
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
raise ConnectionError("ASGI app closed before sending response headers")
|
||||
if message["type"] == "http.response.start":
|
||||
return message
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the ASGI response stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data from the ASGI response queue
|
||||
chunk = await self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
async def _read_chunk(self) -> str | None:
|
||||
"""Read the next body chunk from the ASGI response."""
|
||||
if self._response_queue is None:
|
||||
return None
|
||||
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
return body.decode("utf-8")
|
||||
return None
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the SSE stream and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
|
||||
class ASGIWebSocketExtension:
|
||||
"""Async WebSocket extension for ASGI applications.
|
||||
|
||||
Uses the ASGI websocket protocol with send/receive queues to communicate
|
||||
with the application task.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._app_send_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._app_receive_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self, app: typing.Any, scope: dict[str, typing.Any]) -> None:
|
||||
"""Start the ASGI app task and perform the WebSocket handshake."""
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
return await self._app_receive_queue.get()
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._app_send_queue.put(message)
|
||||
|
||||
self._task = asyncio.create_task(app(scope, receive, send))
|
||||
|
||||
# Send connect and wait for accept
|
||||
await self._app_receive_queue.put({"type": "websocket.connect"})
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"WebSocket connection rejected with code {message.get('code', 1000)}")
|
||||
|
||||
if message["type"] != "websocket.accept":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"Unexpected ASGI message during handshake: {message['type']}")
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Await the next message from the ASGI WebSocket app.
|
||||
Returns None when the app closes the connection."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.send":
|
||||
if "text" in message:
|
||||
return message["text"]
|
||||
if "bytes" in message:
|
||||
return message["bytes"]
|
||||
return b""
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message to the ASGI WebSocket app."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "bytes": bytes(buf)})
|
||||
else:
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "text": buf})
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
await self._app_receive_queue.put({"type": "websocket.disconnect", "code": 1000})
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
from ._async._sse import ASGISSEExtension
|
||||
|
||||
|
||||
class WSGISSEExtension:
|
||||
"""SSE extension for WSGI applications.
|
||||
|
||||
Reads from a WSGI response iterator, buffers text, and parses SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self, generator: typing.Generator[bytes, None, None]) -> None:
|
||||
self._generator = generator
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the WSGI response.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the WSGI response iterator."""
|
||||
try:
|
||||
chunk = next(self._generator)
|
||||
if chunk:
|
||||
return chunk.decode("utf-8")
|
||||
return None
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if hasattr(self._generator, "close"):
|
||||
self._generator.close()
|
||||
|
||||
|
||||
class ThreadASGISSEExtension:
|
||||
"""Synchronous SSE extension wrapping an async ASGISSEExtension.
|
||||
|
||||
Delegates all operations to the async extension on a background event loop,
|
||||
blocking the calling thread via concurrent.futures.Future.
|
||||
"""
|
||||
|
||||
def __init__(self, async_ext: ASGISSEExtension, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._async_ext = async_ext
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._async_ext.closed
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Block until the next SSE event arrives from the ASGI app."""
|
||||
future: Future[ServerSentEvent | str | None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
result = await self._async_ext.next_payload(raw=raw)
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
return future.result()
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the SSE stream and clean up."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.close()
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
from ._async._ws import ASGIWebSocketExtension
|
||||
|
||||
|
||||
class ThreadASGIWebSocketExtension:
|
||||
"""Synchronous WebSocket extension wrapping an async ASGIWebSocketExtension.
|
||||
|
||||
Delegates all operations to the async extension on a background event loop,
|
||||
blocking the calling thread via concurrent.futures.Future.
|
||||
"""
|
||||
|
||||
def __init__(self, async_ext: ASGIWebSocketExtension, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._async_ext = async_ext
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._async_ext.closed
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Block until the next message arrives from the ASGI WebSocket app."""
|
||||
future: Future[str | bytes | None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
result = await self._async_ext.next_payload()
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
return future.result()
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message to the ASGI WebSocket app."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.send_payload(buf)
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the WebSocket and clean up."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.close()
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ..._constant import DEFAULT_POOLBLOCK
|
||||
from ...adapters import HTTPAdapter
|
||||
from ...exceptions import RequestException
|
||||
from ...packages.urllib3.connection import HTTPConnection
|
||||
from ...packages.urllib3.connectionpool import HTTPConnectionPool
|
||||
from ...packages.urllib3.contrib.webextensions import ServerSideEventExtensionFromHTTP, WebSocketExtensionFromHTTP
|
||||
from ...packages.urllib3.poolmanager import PoolManager
|
||||
from ...typing import CacheLayerAltSvcType
|
||||
from ...utils import select_proxy
|
||||
|
||||
|
||||
class UnixServerSideEventExtensionFromHTTP(ServerSideEventExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"psse": "http+unix"}[scheme]
|
||||
|
||||
|
||||
if WebSocketExtensionFromHTTP is not None:
|
||||
|
||||
class UnixWebSocketExtensionFromHTTP(WebSocketExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http+unix"}[scheme]
|
||||
|
||||
|
||||
class UnixHTTPConnection(HTTPConnection):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.host: str = unquote(self.host)
|
||||
self.socket_path = self.host
|
||||
self.host = self.socket_path.split("/")[-1]
|
||||
|
||||
def connect(self):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect(self.socket_path)
|
||||
self.sock = sock
|
||||
self._post_conn()
|
||||
|
||||
|
||||
class UnixHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = UnixHTTPConnection
|
||||
|
||||
|
||||
class UnixAdapter(HTTPAdapter):
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
quic_cache_layer: CacheLayerAltSvcType | None = None,
|
||||
**pool_kwargs: typing.Any,
|
||||
):
|
||||
self._pool_connections = connections
|
||||
self._pool_maxsize = maxsize
|
||||
self._pool_block = block
|
||||
self._quic_cache_layer = quic_cache_layer
|
||||
|
||||
self.poolmanager = PoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
preemptive_quic_cache=quic_cache_layer,
|
||||
**pool_kwargs,
|
||||
)
|
||||
self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
|
||||
self.poolmanager.pool_classes_by_scheme = {
|
||||
"http+unix": UnixHTTPConnectionPool,
|
||||
}
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
proxy = select_proxy(url, proxies)
|
||||
|
||||
if proxy:
|
||||
raise RequestException("unix socket cannot be associated with proxies")
|
||||
|
||||
return self.poolmanager.connection_from_url(url)
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
return request.path_url
|
||||
|
||||
def close(self):
|
||||
self.poolmanager.clear()
|
||||
|
||||
|
||||
__all__ = ("UnixAdapter",)
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ...._constant import DEFAULT_POOLBLOCK
|
||||
from ....adapters import AsyncHTTPAdapter
|
||||
from ....exceptions import RequestException
|
||||
from ....packages.urllib3._async.connection import AsyncHTTPConnection
|
||||
from ....packages.urllib3._async.connectionpool import AsyncHTTPConnectionPool
|
||||
from ....packages.urllib3._async.poolmanager import AsyncPoolManager
|
||||
from ....packages.urllib3.contrib.ssa import AsyncSocket
|
||||
from ....packages.urllib3.contrib.webextensions._async import (
|
||||
AsyncServerSideEventExtensionFromHTTP,
|
||||
AsyncWebSocketExtensionFromHTTP,
|
||||
)
|
||||
from ....typing import CacheLayerAltSvcType
|
||||
from ....utils import select_proxy
|
||||
|
||||
|
||||
class AsyncUnixServerSideEventExtensionFromHTTP(AsyncServerSideEventExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"psse": "http+unix"}[scheme]
|
||||
|
||||
|
||||
if AsyncWebSocketExtensionFromHTTP is not None:
|
||||
|
||||
class AsyncUnixWebSocketExtensionFromHTTP(AsyncWebSocketExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http+unix"}[scheme]
|
||||
|
||||
|
||||
class AsyncUnixHTTPConnection(AsyncHTTPConnection):
|
||||
def __init__(self, host, **kwargs):
|
||||
super().__init__(host, **kwargs)
|
||||
self.host: str = unquote(self.host)
|
||||
self.socket_path = self.host
|
||||
self.host = self.socket_path.split("/")[-1]
|
||||
|
||||
async def connect(self):
|
||||
sock = AsyncSocket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
await sock.connect(self.socket_path)
|
||||
self.sock = sock
|
||||
await self._post_conn()
|
||||
|
||||
|
||||
class AsyncUnixHTTPConnectionPool(AsyncHTTPConnectionPool):
|
||||
ConnectionCls = AsyncUnixHTTPConnection
|
||||
|
||||
|
||||
class AsyncUnixAdapter(AsyncHTTPAdapter):
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
quic_cache_layer: CacheLayerAltSvcType | None = None,
|
||||
**pool_kwargs: typing.Any,
|
||||
):
|
||||
self._pool_connections = connections
|
||||
self._pool_maxsize = maxsize
|
||||
self._pool_block = block
|
||||
self._quic_cache_layer = quic_cache_layer
|
||||
|
||||
self.poolmanager = AsyncPoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
preemptive_quic_cache=quic_cache_layer,
|
||||
**pool_kwargs,
|
||||
)
|
||||
self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
|
||||
self.poolmanager.pool_classes_by_scheme = {
|
||||
"http+unix": AsyncUnixHTTPConnectionPool,
|
||||
}
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
proxy = select_proxy(url, proxies)
|
||||
|
||||
if proxy:
|
||||
raise RequestException("unix socket cannot be associated with proxies")
|
||||
|
||||
return self.poolmanager.connection_from_url(url)
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
return request.path_url
|
||||
|
||||
|
||||
__all__ = ("AsyncUnixAdapter",)
|
||||
Reference in New Issue
Block a user