fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097) 포트 충돌로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,4 @@
"""
This subpackage hold anything that is very relevant
to the HTTP ecosystem but not per-say Niquests core logic.
"""

View File

@@ -0,0 +1,394 @@
from __future__ import annotations
import time
import typing
from datetime import timedelta
from pyodide.ffi import run_sync # type: ignore[import]
from pyodide.http import pyfetch # type: ignore[import]
from ..._constant import DEFAULT_RETRIES
from ...adapters import BaseAdapter
from ...exceptions import ConnectionError, ConnectTimeout
from ...models import PreparedRequest, Response
from ...packages.urllib3.exceptions import MaxRetryError
from ...packages.urllib3.response import BytesQueueBuffer
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
from ...packages.urllib3.util import Timeout as TimeoutSauce
from ...packages.urllib3.util.retry import Retry
from ...structures import CaseInsensitiveDict
from ...utils import get_encoding_from_headers
if typing.TYPE_CHECKING:
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
class _PyodideRawIO:
    """File-like wrapper around a Pyodide Fetch response with true streaming via JSPI.
    When constructed with a JS Response object, reads chunks incrementally from the
    JavaScript ReadableStream using ``run_sync(reader.read())`` per chunk.
    When constructed with preloaded content (non-streaming), serves from a memory buffer.
    """
    def __init__(
        self,
        js_response: typing.Any = None,
        preloaded_content: bytes | None = None,
    ) -> None:
        # Live JS Response object when streaming; None for buffered or
        # extension (WS/SSE) use where there is nothing to stream here.
        self._js_response = js_response
        # Bytes received from JS but not yet handed out to the caller.
        self._buffer = BytesQueueBuffer()
        self._closed = False
        # Nothing left to pull from JS when the body was preloaded or absent.
        self._finished = preloaded_content is not None or js_response is None
        self._reader: typing.Any = None  # lazily created ReadableStreamDefaultReader
        self.headers: dict[str, str] = {}
        self.extension: typing.Any = None  # set by the adapter for WS/SSE responses
        if preloaded_content is not None:
            self._buffer.put(preloaded_content)
    def _ensure_reader(self) -> None:
        """Initialize the ReadableStream reader if not already done."""
        if self._reader is None and self._js_response is not None:
            try:
                body = self._js_response.body
                if body is not None:
                    self._reader = body.getReader()
            except Exception:
                # Defensive: body/getReader can raise on locked or detached
                # streams; leave _reader None so reads report end-of-stream.
                pass
    def _get_next_chunk(self) -> bytes | None:
        """Read the next chunk from the JS ReadableStream, blocking via JSPI.

        Returns None at end-of-stream or on any JS-side read failure.
        """
        self._ensure_reader()
        if self._reader is None:
            return None
        try:
            result = run_sync(self._reader.read())
            if result.done:
                return None
            value = result.value
            if value is not None:
                # Uint8Array -> Python bytes.
                return bytes(value.to_py())
            return None
        except Exception:
            # Any JS-side failure is treated as end-of-stream.
            return None
    def read(
        self,
        amt: int | None = None,
        decode_content: bool = True,
    ) -> bytes:
        """Read up to ``amt`` bytes; ``amt`` of None or negative reads everything left.

        ``decode_content`` is accepted for urllib3 API compatibility but is not
        used here (NOTE(review): presumably fetch already decoded any
        content-encoding — confirm against browser behavior).
        """
        if self._closed:
            # After close, still drain whatever was already buffered.
            return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
        if self._finished:
            if len(self._buffer) == 0:
                return b""
            if amt is None or amt < 0:
                return self._buffer.get(len(self._buffer))
            return self._buffer.get(min(amt, len(self._buffer)))
        if amt is None or amt < 0:
            # Read everything remaining
            while True:
                chunk = self._get_next_chunk()
                if chunk is None:
                    break
                self._buffer.put(chunk)
            self._finished = True
            return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
        # Read until we have enough bytes or stream ends
        while len(self._buffer) < amt and not self._finished:
            chunk = self._get_next_chunk()
            if chunk is None:
                self._finished = True
                break
            self._buffer.put(chunk)
        if len(self._buffer) == 0:
            return b""
        return self._buffer.get(min(amt, len(self._buffer)))
    def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
        """Iterate over chunks of the response."""
        while True:
            chunk = self.read(amt)
            if not chunk:
                break
            yield chunk
    def close(self) -> None:
        """Cancel the underlying JS stream and drop references to JS objects."""
        self._closed = True
        if self._reader is not None:
            try:
                run_sync(self._reader.cancel())
            except Exception:
                # Defensive: cancel may reject if the stream already errored/closed.
                pass
            self._reader = None
        self._js_response = None
    def __iter__(self) -> typing.Iterator[bytes]:
        return self
    def __next__(self) -> bytes:
        # 8192 is the chunk size used for plain iteration.
        chunk = self.read(8192)
        if not chunk:
            raise StopIteration
        return chunk
class PyodideAdapter(BaseAdapter):
    """Synchronous adapter for making HTTP requests in Pyodide using JSPI + pyfetch.

    JSPI (JavaScript Promise Integration) lets ``run_sync`` suspend a Python call
    frame on a JS Promise, so the synchronous Niquests API can be backed by the
    browser's asynchronous Fetch API.
    """

    def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
        """:param max_retries: retry policy, either an int or a urllib3 ``Retry``."""
        super().__init__()
        if isinstance(max_retries, Retry):
            self.max_retries = max_retries
        else:
            self.max_retries = Retry.from_int(max_retries)

    def __repr__(self) -> str:
        return "<PyodideAdapter WASM/>"

    def send(
        self,
        request: PreparedRequest,
        stream: bool = False,
        timeout: int | float | tuple | TimeoutSauce | None = None,
        verify: TLSVerifyType = True,
        cert: TLSClientCertType | None = None,
        proxies: ProxyType | None = None,
        on_post_connection: typing.Callable[[typing.Any], None] | None = None,
        on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
        on_early_response: typing.Callable[[Response], None] | None = None,
        multiplexed: bool = False,
    ) -> Response:
        """Send a PreparedRequest using Pyodide's pyfetch (synchronous via JSPI).

        ``verify``, ``cert``, ``proxies`` and the callbacks are accepted for
        interface compatibility; TLS and proxying are controlled by the browser.

        :raises ConnectTimeout: when the fetch is aborted by the timeout signal.
        :raises ConnectionError: for any other fetch failure.
        """
        # Collapse the various timeout spellings into a single float (seconds).
        if isinstance(timeout, tuple):
            if len(timeout) == 3:
                timeout = timeout[2] or timeout[0]  # prefer total, fallback connect
            else:
                timeout = timeout[0]  # use connect
        elif isinstance(timeout, TimeoutSauce):
            timeout = timeout.total or timeout.connect_timeout
        retries = self.max_retries
        method = request.method or "GET"
        start = time.time()
        while True:
            try:
                response = self._do_send(request, stream, timeout)
            except Exception as err:
                # Connection-level failure: let urllib3's Retry decide whether to
                # retry; increment() raises MaxRetryError when exhausted.
                retries = retries.increment(method, request.url, error=err)
                retries.sleep()
                continue
            # Mock a urllib3 response so Retry can inspect status/headers.
            base_response = BaseHTTPResponse(
                body=b"",
                headers=response.headers,
                status=response.status_code,
                request_method=request.method,
                request_url=request.url,
            )
            has_retry_after = bool(response.headers.get("Retry-After"))
            if retries.is_retry(method, response.status_code, has_retry_after):
                try:
                    retries = retries.increment(method, request.url, response=base_response)
                except MaxRetryError:
                    if retries.raise_on_status:
                        raise
                    # Fix: record elapsed time on this exhausted-retries path too,
                    # so callers never see a zero elapsed on a real response.
                    response.elapsed = timedelta(seconds=time.time() - start)
                    return response
                retries.sleep(base_response)
                continue
            response.elapsed = timedelta(seconds=time.time() - start)
            return response

    def _do_send(
        self,
        request: PreparedRequest,
        stream: bool,
        timeout: int | float | None,
    ) -> Response:
        """Perform the actual request using pyfetch made synchronous via JSPI."""
        url = request.url or ""
        scheme = url.split("://")[0].lower() if "://" in url else ""
        # WebSocket: delegate to browser native WebSocket API
        if scheme in ("ws", "wss"):
            return self._do_send_ws(request, url)
        # SSE: delegate to pyfetch streaming + manual SSE parsing
        if scheme in ("sse", "psse"):
            return self._do_send_sse(request, url, scheme)
        # Prepare headers — skip headers the browser manages itself.
        headers_dict: dict[str, str] = {}
        if request.headers:
            for key, value in request.headers.items():
                if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
                    headers_dict[key] = value
        # Prepare body — fetch needs a single bytes payload, so iterables are drained.
        body = request.body
        if body is not None:
            if isinstance(body, str):
                body = body.encode("utf-8")
            elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
                chunks: list[bytes] = []
                for chunk in body:
                    if isinstance(chunk, str):
                        chunks.append(chunk.encode("utf-8"))
                    elif isinstance(chunk, bytes):
                        chunks.append(chunk)
                body = b"".join(chunks)
        # Build fetch options
        fetch_options: dict[str, typing.Any] = {
            "method": request.method or "GET",
            "headers": headers_dict,
        }
        if body:
            fetch_options["body"] = body
        # Use AbortSignal.timeout() for timeout — the browser-native mechanism.
        # run_sync cannot interrupt a JS Promise.
        signal = None
        if timeout is not None:
            from js import AbortSignal  # type: ignore[import]
            signal = AbortSignal.timeout(int(timeout * 1000))
        try:
            js_response = run_sync(pyfetch(request.url, signal=signal, **fetch_options))
        except Exception as e:
            # pyfetch surfaces JS errors as generic exceptions; classify by message.
            err_str = str(e).lower()
            if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
                raise ConnectTimeout(f"Connection to {request.url} timed out")
            raise ConnectionError(f"Failed to fetch {request.url}: {e}")
        # Parse response headers
        response_headers: dict[str, str] = {}
        try:
            if hasattr(js_response, "headers"):
                js_headers = js_response.headers
                if hasattr(js_headers, "items"):
                    # Pyodide FetchResponse exposes a dict-like headers object.
                    for key, value in js_headers.items():
                        response_headers[key] = value
                elif hasattr(js_headers, "entries"):
                    # JavaScript Headers.entries() returns an iterator of pairs.
                    for entry in js_headers.entries():
                        response_headers[entry[0]] = entry[1]
        except Exception:
            # Defensive: header objects differ between Pyodide versions.
            pass
        # Build response object
        response = Response()
        response.status_code = js_response.status
        response.headers = CaseInsensitiveDict(response_headers)
        response.request = request
        response.url = js_response.url or request.url
        response.encoding = get_encoding_from_headers(response_headers)
        try:
            response.reason = js_response.status_text or ""
        except Exception:
            response.reason = ""
        if stream:
            # Streaming: pass the underlying JS Response to the raw IO
            # so it can read chunks incrementally via run_sync(reader.read())
            raw_io = _PyodideRawIO(js_response=js_response.js_response)
            raw_io.headers = response_headers
            response.raw = raw_io  # type: ignore
            response._content = False  # type: ignore[assignment]
            response._content_consumed = False
        else:
            # Non-streaming: read full body upfront
            try:
                response_body: bytes = run_sync(js_response.bytes())
            except Exception:
                # Best-effort: an unreadable body degrades to empty content.
                response_body = b""
            raw_io = _PyodideRawIO(preloaded_content=response_body)
            raw_io.headers = response_headers
            response.raw = raw_io  # type: ignore
            response._content = response_body
        return response

    def _do_send_ws(self, request: PreparedRequest, url: str) -> Response:
        """Handle WebSocket connections via browser native WebSocket API."""
        from ._ws import PyodideWebSocketExtension
        try:
            ext = PyodideWebSocketExtension(url)
        except Exception as e:
            raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
        # Synthesize the 101 upgrade response callers expect.
        response = Response()
        response.status_code = 101
        response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
        response.request = request
        response.url = url
        response.reason = "Switching Protocols"
        raw_io = _PyodideRawIO()
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        response._content = b""
        return response

    def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> Response:
        """Handle SSE connections via pyfetch streaming + manual parsing."""
        from ._sse import PyodideSSEExtension
        # Map sse:// -> https:// and psse:// -> http:// for the actual fetch.
        http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
        # Pass through user-provided headers
        headers_dict: dict[str, str] = {}
        if request.headers:
            for key, value in request.headers.items():
                if key.lower() not in ("host", "content-length", "connection"):
                    headers_dict[key] = value
        try:
            ext = PyodideSSEExtension(http_url, headers=headers_dict)
        except Exception as e:
            raise ConnectionError(f"SSE connection to {url} failed: {e}")
        # Synthesize a 200 text/event-stream response carrying the extension.
        response = Response()
        response.status_code = 200
        response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
        response.request = request
        response.url = url
        response.reason = "OK"
        raw_io = _PyodideRawIO()
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        response._content = False  # type: ignore[assignment]
        response._content_consumed = False
        return response

    def close(self) -> None:
        """Clean up adapter resources (nothing held beyond per-request objects)."""
        pass


__all__ = ("PyodideAdapter",)

View File

@@ -0,0 +1,451 @@
from __future__ import annotations
import asyncio
import time
import typing
from datetime import timedelta
from pyodide.http import pyfetch # type: ignore[import]
from ...._constant import DEFAULT_RETRIES
from ....adapters import AsyncBaseAdapter
from ....exceptions import ConnectionError, ConnectTimeout, ReadTimeout
from ....models import AsyncResponse, PreparedRequest, Response
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
from ....packages.urllib3.exceptions import MaxRetryError
from ....packages.urllib3.response import BytesQueueBuffer
from ....packages.urllib3.util import Timeout as TimeoutSauce
from ....packages.urllib3.util.retry import Retry
from ....structures import CaseInsensitiveDict
from ....utils import _swap_context, get_encoding_from_headers
if typing.TYPE_CHECKING:
from ....typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
class _AsyncPyodideRawIO:
    """
    Async file-like wrapper around Pyodide Fetch response for true streaming.
    This class uses the JavaScript ReadableStream API through Pyodide to provide
    genuine streaming support without buffering the entire response in memory.
    """
    def __init__(
        self,
        js_response: typing.Any,  # JavaScript Response object
        timeout: float | None = None,
    ) -> None:
        # Live JS Response when streaming; None for buffered/extension responses.
        self._js_response = js_response
        # Per-chunk read timeout in seconds; None disables the timeout.
        self._timeout = timeout
        # Bytes received from JS but not yet handed out to the caller.
        self._buffer = BytesQueueBuffer()
        self._closed = False
        # True once the JS stream reports done (or a read fails).
        self._finished = False
        self._reader: typing.Any = None  # JavaScript ReadableStreamDefaultReader
        self.headers: dict[str, str] = {}
        self.extension: typing.Any = None  # set by the adapter for WS/SSE responses
    async def _ensure_reader(self) -> None:
        """Initialize the stream reader if not already done."""
        if self._reader is None and self._js_response is not None:
            try:
                body = self._js_response.body
                if body is not None:
                    self._reader = body.getReader()
            except Exception:
                # Defensive: leave _reader None so reads report end-of-stream.
                pass
    async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
        """
        Read up to `amt` bytes from the response stream.
        When `amt` is None, reads the entire remaining response.
        ``decode_content`` is accepted for urllib3 API compatibility but is not
        used here (NOTE(review): presumably fetch already decoded any
        content-encoding — confirm against browser behavior).
        """
        if self._closed:
            # After close, still drain whatever was already buffered.
            return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
        if self._finished:
            if len(self._buffer) == 0:
                return b""
            if amt is None or amt < 0:
                return self._buffer.get(len(self._buffer))
            return self._buffer.get(min(amt, len(self._buffer)))
        if amt is None or amt < 0:
            # Read everything remaining
            async for chunk in self._stream_chunks():
                self._buffer.put(chunk)
            self._finished = True
            return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
        # Read until we have enough bytes or stream ends
        while len(self._buffer) < amt and not self._finished:
            chunk = await self._get_next_chunk()  # type: ignore[assignment]
            if chunk is None:
                self._finished = True
                break
            self._buffer.put(chunk)
        if len(self._buffer) == 0:
            return b""
        return self._buffer.get(min(amt, len(self._buffer)))
    async def _get_next_chunk(self) -> bytes | None:
        """Read the next chunk from the JavaScript ReadableStream.

        Returns None at end-of-stream or on any JS-side read failure.
        :raises ReadTimeout: when a single chunk read exceeds ``self._timeout``.
        """
        await self._ensure_reader()
        if self._reader is None:
            return None
        try:
            if self._timeout is not None:
                async with asyncio_timeout(self._timeout):
                    result = await self._reader.read()
            else:
                result = await self._reader.read()
            if result.done:
                return None
            value = result.value
            if value is not None:
                # Convert Uint8Array to bytes
                return bytes(value.to_py())
            return None
        except asyncio.TimeoutError:
            raise ReadTimeout("Read timed out while streaming Pyodide response")
        except Exception:
            # Any other JS-side failure is treated as end-of-stream.
            return None
    async def _stream_chunks(self) -> typing.AsyncGenerator[bytes, None]:
        """Async generator that yields chunks from the stream."""
        await self._ensure_reader()
        if self._reader is None:
            return
        while True:
            chunk = await self._get_next_chunk()
            if chunk is None:
                break
            yield chunk
    def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes, None]:
        """Return an async generator that yields chunks of `amt` bytes."""
        return self._async_stream(amt)
    async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes, None]:
        """Internal async generator for streaming."""
        while True:
            chunk = await self.read(amt)
            if not chunk:
                break
            yield chunk
    async def close(self) -> None:
        """Close the stream and release resources."""
        self._closed = True
        if self._reader is not None:
            try:
                await self._reader.cancel()
            except Exception:
                # Defensive: cancel may reject if the stream already errored/closed.
                pass
            self._reader = None
        self._js_response = None
    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        return self
    async def __anext__(self) -> bytes:
        # 8192 is the chunk size used for plain async iteration.
        chunk = await self.read(8192)
        if not chunk:
            raise StopAsyncIteration
        return chunk
class AsyncPyodideAdapter(AsyncBaseAdapter):
    """Async adapter for making HTTP requests in Pyodide using the native pyfetch API."""

    def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
        """
        Initialize the async Pyodide adapter.
        :param max_retries: Maximum number of retries for requests.
        """
        super().__init__()
        if isinstance(max_retries, Retry):
            self.max_retries = max_retries
        else:
            self.max_retries = Retry.from_int(max_retries)

    def __repr__(self) -> str:
        return "<AsyncPyodideAdapter WASM/>"

    async def send(
        self,
        request: PreparedRequest,
        stream: bool = False,
        timeout: int | float | tuple | TimeoutSauce | None = None,
        verify: TLSVerifyType = True,
        cert: TLSClientCertType | None = None,
        proxies: ProxyType | None = None,
        on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
        on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
        on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
        multiplexed: bool = False,
    ) -> AsyncResponse:
        """Send a PreparedRequest using Pyodide's pyfetch (JavaScript Fetch API).

        ``verify``, ``cert``, ``proxies`` and the callbacks are accepted for
        interface compatibility; TLS and proxying are controlled by the browser.

        :raises ConnectTimeout: when the fetch is aborted by the timeout signal.
        :raises ReadTimeout: when reading the body exceeds the timeout.
        :raises ConnectionError: for any other fetch failure.
        """
        # Collapse the various timeout spellings into a single float (seconds).
        if isinstance(timeout, tuple):
            if len(timeout) == 3:
                timeout = timeout[2] or timeout[0]  # prefer total, fallback connect
            else:
                timeout = timeout[0]  # use connect
        elif isinstance(timeout, TimeoutSauce):
            timeout = timeout.total or timeout.connect_timeout
        retries = self.max_retries
        method = request.method or "GET"
        start = time.time()
        while True:
            try:
                response = await self._do_send(request, stream, timeout)
            except Exception as err:
                # Connection-level failure: let urllib3's Retry decide whether to
                # retry; increment() raises MaxRetryError when exhausted.
                retries = retries.increment(method, request.url, error=err)
                await retries.async_sleep()
                continue
            # We rely on the urllib3 implementation for retries
            # so we basically mock a response to get it to work
            base_response = BaseHTTPResponse(
                body=b"",
                headers=response.headers,
                status=response.status_code,
                request_method=request.method,
                request_url=request.url,
            )
            # Check if we should retry based on status code
            has_retry_after = bool(response.headers.get("Retry-After"))
            if retries.is_retry(method, response.status_code, has_retry_after):
                try:
                    retries = retries.increment(method, request.url, response=base_response)
                except MaxRetryError:
                    if retries.raise_on_status:
                        raise
                    # Fix: record elapsed time on this exhausted-retries path too,
                    # so callers never see a zero elapsed on a real response.
                    response.elapsed = timedelta(seconds=time.time() - start)
                    return response
                await retries.async_sleep(base_response)
                continue
            response.elapsed = timedelta(seconds=time.time() - start)
            return response

    async def _do_send(
        self,
        request: PreparedRequest,
        stream: bool,
        timeout: int | float | None,
    ) -> AsyncResponse:
        """Perform the actual request using Pyodide's pyfetch."""
        url = request.url or ""
        scheme = url.split("://")[0].lower() if "://" in url else ""
        # WebSocket: delegate to browser native WebSocket API
        if scheme in ("ws", "wss"):
            return await self._do_send_ws(request, url)
        # SSE: delegate to pyfetch streaming + manual SSE parsing
        if scheme in ("sse", "psse"):
            return await self._do_send_sse(request, url, scheme)
        # Prepare headers
        headers_dict: dict[str, str] = {}
        if request.headers:
            for key, value in request.headers.items():
                # Skip headers that browsers don't allow to be set
                if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
                    headers_dict[key] = value
        # Prepare body — fetch needs a single bytes payload, so iterables are drained.
        body = request.body
        if body is not None:
            if isinstance(body, str):
                body = body.encode("utf-8")
            elif hasattr(body, "__aiter__"):
                # Consume async iterable body
                chunks: list[bytes] = []
                async for chunk in body:  # type: ignore[union-attr]
                    if isinstance(chunk, str):
                        chunks.append(chunk.encode("utf-8"))
                    else:
                        chunks.append(chunk)
                body = b"".join(chunks)
            elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
                # Consume sync iterable body
                chunks = []
                for chunk in body:
                    if isinstance(chunk, str):
                        chunks.append(chunk.encode("utf-8"))
                    elif isinstance(chunk, bytes):
                        chunks.append(chunk)
                body = b"".join(chunks)
        # Build fetch options
        fetch_options: dict[str, typing.Any] = {
            "method": request.method or "GET",
            "headers": headers_dict,
        }
        if body:
            fetch_options["body"] = body
        # Use AbortSignal.timeout() for timeout — the browser-native mechanism.
        # asyncio_timeout cannot interrupt a single JS Promise await.
        signal = None
        if timeout is not None:
            from js import AbortSignal  # type: ignore[import]
            signal = AbortSignal.timeout(int(timeout * 1000))
        try:
            js_response = await pyfetch(request.url, signal=signal, **fetch_options)
        except Exception as e:
            # pyfetch surfaces JS errors as generic exceptions; classify by message.
            err_str = str(e).lower()
            if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
                raise ConnectTimeout(f"Connection to {request.url} timed out")
            raise ConnectionError(f"Failed to fetch {request.url}: {e}")
        # Parse response headers
        response_headers: dict[str, str] = {}
        try:
            # Pyodide's FetchResponse has headers as a dict-like object
            if hasattr(js_response, "headers"):
                js_headers = js_response.headers
                if hasattr(js_headers, "items"):
                    for key, value in js_headers.items():
                        response_headers[key] = value
                elif hasattr(js_headers, "entries"):
                    # JavaScript Headers.entries() returns an iterator
                    for entry in js_headers.entries():
                        response_headers[entry[0]] = entry[1]
        except Exception:
            # Defensive: header objects differ between Pyodide versions.
            pass
        # Build response object
        response = Response()
        response.status_code = js_response.status
        response.headers = CaseInsensitiveDict(response_headers)
        response.request = request
        response.url = js_response.url or request.url
        response.encoding = get_encoding_from_headers(response_headers)
        # Try to get status text
        try:
            response.reason = js_response.status_text or ""
        except Exception:
            response.reason = ""
        if stream:
            # For streaming: set up async raw IO using the JS Response object
            # This provides true streaming without buffering entire response
            raw_io = _AsyncPyodideRawIO(js_response.js_response, timeout)
            raw_io.headers = response_headers
            response.raw = raw_io  # type: ignore
            response._content = False  # type: ignore[assignment]
            response._content_consumed = False
        else:
            # For non-streaming: get full response body
            try:
                if timeout is not None:
                    async with asyncio_timeout(timeout):
                        response_body = await js_response.bytes()
                else:
                    response_body = await js_response.bytes()
            except asyncio.TimeoutError:
                raise ReadTimeout(f"Read timed out for {request.url}")
            response._content = response_body
            raw_io = _AsyncPyodideRawIO(None, timeout)
            raw_io.headers = response_headers
            response.raw = raw_io  # type: ignore
        # Promote the sync Response into its async counterpart in place.
        _swap_context(response)
        return response  # type: ignore[return-value]

    async def _do_send_ws(self, request: PreparedRequest, url: str) -> AsyncResponse:
        """Handle WebSocket connections via browser native WebSocket API."""
        from ._ws import AsyncPyodideWebSocketExtension
        ext = AsyncPyodideWebSocketExtension()
        try:
            await ext.start(url)
        except Exception as e:
            raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
        # Synthesize the 101 upgrade response callers expect.
        response = Response()
        response.status_code = 101
        response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
        response.request = request
        response.url = url
        response.reason = "Switching Protocols"
        raw_io = _AsyncPyodideRawIO(None)
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        response._content = b""
        _swap_context(response)
        return response  # type: ignore[return-value]

    async def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> AsyncResponse:
        """Handle SSE connections via pyfetch streaming + manual parsing."""
        from ._sse import AsyncPyodideSSEExtension
        # Map sse:// -> https:// and psse:// -> http:// for the actual fetch.
        http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
        # Pass through user-provided headers
        headers_dict: dict[str, str] = {}
        if request.headers:
            for key, value in request.headers.items():
                if key.lower() not in ("host", "content-length", "connection"):
                    headers_dict[key] = value
        ext = AsyncPyodideSSEExtension()
        try:
            await ext.start(http_url, headers=headers_dict)
        except Exception as e:
            raise ConnectionError(f"SSE connection to {url} failed: {e}")
        # Synthesize a 200 text/event-stream response carrying the extension.
        response = Response()
        response.status_code = 200
        response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
        response.request = request
        response.url = url
        response.reason = "OK"
        raw_io = _AsyncPyodideRawIO(None)
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        response._content = False  # type: ignore[assignment]
        response._content_consumed = False
        _swap_context(response)
        return response  # type: ignore[return-value]

    async def close(self) -> None:
        """Clean up adapter resources (nothing held beyond per-request objects)."""
        pass


__all__ = ("AsyncPyodideAdapter",)

View File

@@ -0,0 +1,152 @@
from __future__ import annotations

import codecs
import typing

from pyodide.http import pyfetch  # type: ignore[import]

from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
class AsyncPyodideSSEExtension:
    """Async SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.
    Reads from a ReadableStream reader, buffers partial lines,
    and parses complete SSE events.
    """

    def __init__(self) -> None:
        self._closed = False
        self._buffer: str = ""  # decoded text not yet parsed into events
        self._last_event_id: str | None = None  # sticky "last event id" per SSE spec
        self._reader: typing.Any = None  # ReadableStreamDefaultReader
        # Fix: network chunk boundaries may split a multi-byte UTF-8 sequence,
        # so a plain per-chunk .decode() could raise UnicodeDecodeError. The
        # incremental decoder carries partial sequences across chunks.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    async def start(self, url: str, headers: dict[str, str] | None = None) -> None:
        """Open the SSE stream via pyfetch.

        :param url: plain http(s) URL of the event-stream endpoint.
        :param headers: extra request headers, merged over the SSE defaults.
        """
        fetch_options: dict[str, typing.Any] = {
            "method": "GET",
            "headers": {
                "Accept": "text/event-stream",
                "Cache-Control": "no-store",
                **(headers or {}),
            },
        }
        js_response = await pyfetch(url, **fetch_options)
        body = js_response.js_response.body
        if body is not None:
            self._reader = body.getReader()

    @property
    def closed(self) -> bool:
        """Whether the stream has ended or close() was called."""
        return self._closed

    async def _read_chunk(self) -> str | None:
        """Read and decode the next chunk from the ReadableStream.

        Returns None at end-of-stream or on any read error. May return an
        empty string when a chunk ends inside a multi-byte UTF-8 sequence.
        """
        if self._reader is None:
            return None
        try:
            result = await self._reader.read()
            if result.done:
                return None
            value = result.value
            if value is not None:
                # Fix: decode incrementally so chunk boundaries that split a
                # multi-byte character no longer raise UnicodeDecodeError.
                return self._decoder.decode(bytes(value.to_py()))
            return None
        except Exception:
            return None

    async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Read and parse the next SSE event from the stream.
        Returns None when the stream ends.

        :param raw: return the unparsed event block (plus trailing separator)
            instead of a ServerSentEvent.
        :raises OSError: if called after the extension was closed.
        """
        if self._closed:
            raise OSError("The SSE extension is closed")
        # Keep reading chunks until we have a complete event (double newline)
        while True:
            # Check if we already have a complete event in the buffer
            sep_idx = self._buffer.find("\n\n")
            if sep_idx == -1:
                sep_idx = self._buffer.find("\r\n\r\n")
                if sep_idx != -1:
                    sep_len = 4
                else:
                    sep_len = 2
            else:
                sep_len = 2
            if sep_idx != -1:
                raw_event = self._buffer[:sep_idx]
                self._buffer = self._buffer[sep_idx + sep_len :]
                event = self._parse_event(raw_event)
                if event is not None:
                    if raw:
                        return raw_event + "\n\n"
                    return event
                # Empty event (e.g. just comments), try next
                continue
            # Need more data
            chunk = await self._read_chunk()
            if chunk is None:
                self._closed = True
                return None
            self._buffer += chunk

    def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
        """Parse a raw SSE event block into a ServerSentEvent.

        Returns None for blocks containing only comments or unknown fields.
        """
        kwargs: dict[str, typing.Any] = {}
        for line in raw_event.splitlines():
            if not line or line.startswith(":"):
                continue  # blank lines and ":" comment lines carry no fields
            key, _, value = line.partition(":")
            if key not in {"event", "data", "retry", "id"}:
                continue
            if value.startswith(" "):
                value = value[1:]  # a single leading space is part of the field syntax
            if key == "id":
                if "\u0000" in value:
                    continue  # ids containing NUL are ignored
            if key == "retry":
                try:
                    value = int(value)  # type: ignore[assignment]
                except (ValueError, TypeError):
                    continue  # non-numeric retry values are ignored
            kwargs[key] = value
        if not kwargs:
            return None
        # Carry the last seen event id forward when an event omits its own.
        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id
        event = ServerSentEvent(**kwargs)
        if event.id:
            self._last_event_id = event.id
        return event

    async def send_payload(self, buf: str | bytes) -> None:
        """SSE is one-way only."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    async def close(self) -> None:
        """Close the stream and release resources."""
        if self._closed:
            return
        self._closed = True
        if self._reader is not None:
            try:
                await self._reader.cancel()
            except Exception:
                # Defensive: cancel may reject if the stream already closed.
                pass
            self._reader = None

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
import asyncio
import typing
from pyodide.ffi import create_proxy # type: ignore[import]
try:
from js import WebSocket as JSWebSocket # type: ignore[import]
except ImportError:
JSWebSocket = None
class AsyncPyodideWebSocketExtension:
    """Async WebSocket extension for Pyodide using the browser's native WebSocket API.
    Messages are queued from JS callbacks and dequeued by next_payload()
    which awaits until a message arrives.
    """
    def __init__(self) -> None:
        self._closed = False
        # Incoming messages; a None sentinel marks a server-initiated close.
        self._queue: asyncio.Queue[str | bytes | None] = asyncio.Queue()
        # JsProxy handles kept so they can be destroy()ed on close.
        self._proxies: list[typing.Any] = []
        self._ws: typing.Any = None  # the underlying JS WebSocket object
    async def start(self, url: str) -> None:
        """Open the WebSocket connection and wait until it's ready.

        :param url: ws:// or wss:// endpoint.
        :raises OSError: when the JS runtime has no WebSocket (e.g. Node.js).
        :raises ConnectionError: when the browser reports a connection error.
        """
        if JSWebSocket is None:  # Defensive: depends on JS runtime
            raise OSError(
                "WebSocket is not available in this JavaScript runtime. "
                "Browser environment required (not supported in Node.js)."
            )
        loop = asyncio.get_running_loop()
        # Resolved (or failed) by the onopen/onerror callbacks below.
        open_future: asyncio.Future[None] = loop.create_future()
        self._ws = JSWebSocket.new(url)
        # Receive binary frames as ArrayBuffer rather than Blob.
        self._ws.binaryType = "arraybuffer"
        # JS→Python callbacks: invoked by the browser event loop, not Python call
        # frames, so coverage cannot trace into them.
        def _onopen(event: typing.Any) -> None:  # Defensive: JS callback
            if not open_future.done():
                open_future.set_result(None)
        def _onerror(event: typing.Any) -> None:  # Defensive: JS callback
            if not open_future.done():
                open_future.set_exception(ConnectionError("WebSocket connection failed"))
        def _onmessage(event: typing.Any) -> None:  # Defensive: JS callback
            data = event.data
            if isinstance(data, str):
                self._queue.put_nowait(data)
            else:
                # ArrayBuffer → bytes via Pyodide
                self._queue.put_nowait(bytes(data.to_py()))
        def _onclose(event: typing.Any) -> None:  # Defensive: JS callback
            self._queue.put_nowait(None)
        for name, fn in [
            ("onopen", _onopen),
            ("onerror", _onerror),
            ("onmessage", _onmessage),
            ("onclose", _onclose),
        ]:
            # create_proxy keeps the Python callables alive on the JS side.
            proxy = create_proxy(fn)
            self._proxies.append(proxy)
            setattr(self._ws, name, proxy)
        await open_future
    @property
    def closed(self) -> bool:
        """Whether the connection was closed (locally or by the server)."""
        return self._closed
    async def next_payload(self) -> str | bytes | None:
        """Await the next message from the WebSocket.
        Returns None when the remote end closes the connection.

        :raises OSError: if called after the extension was closed.
        """
        if self._closed:  # Defensive: caller should not call after close
            raise OSError("The WebSocket extension is closed")
        msg = await self._queue.get()
        if msg is None:  # Defensive: server-initiated close sentinel
            self._closed = True
        return msg
    async def send_payload(self, buf: str | bytes) -> None:
        """Send a message over the WebSocket.

        :raises OSError: if called after the extension was closed.
        """
        if self._closed:  # Defensive: caller should not call after close
            raise OSError("The WebSocket extension is closed")
        if isinstance(buf, (bytes, bytearray)):
            from js import Uint8Array  # type: ignore[import]
            self._ws.send(Uint8Array.new(buf))
        else:
            self._ws.send(buf)
    async def ping(self) -> None:
        """No-op — browser WebSocket handles ping/pong at protocol level."""
        pass
    async def close(self) -> None:
        """Close the WebSocket and clean up proxies."""
        if self._closed:  # Defensive: idempotent close
            return
        self._closed = True
        try:
            self._ws.close()
        except Exception:  # Defensive: suppress JS errors on teardown
            pass
        for proxy in self._proxies:
            try:
                proxy.destroy()
            except Exception:  # Defensive: suppress JS errors on teardown
                pass
        self._proxies.clear()

View File

@@ -0,0 +1,151 @@
from __future__ import annotations
import typing
from pyodide.ffi import run_sync # type: ignore[import]
from pyodide.http import pyfetch # type: ignore[import]
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
class PyodideSSEExtension:
    """SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.

    Synchronous via JSPI (run_sync). Reads from a ReadableStream reader,
    buffers partial lines, and parses complete SSE events.

    Fixes over the first draft:
    - multiple ``data:`` lines in one event are joined with "\\n" per the
      WHATWG event-stream processing model (they used to overwrite each other);
    - UTF-8 is decoded incrementally, so a multi-byte character split across
      two network chunks no longer aborts the stream;
    - ``raw=True`` returns the terminator actually consumed (CRLF events used
      to come back with an LF-LF terminator appended).
    """

    def __init__(self, url: str, headers: dict[str, str] | None = None) -> None:
        """Open the SSE stream immediately (blocking via JSPI) and grab a reader."""
        import codecs  # local import: only needed by this extension

        self._closed = False
        # Undelivered decoded text, possibly holding partial lines/events.
        self._buffer: str = ""
        # Last seen event id, re-attached to id-less events per the SSE spec.
        self._last_event_id: str | None = None
        self._reader: typing.Any = None
        # Stateful decoder: keeps trailing partial multi-byte sequences until
        # the next chunk completes them.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()
        fetch_options: dict[str, typing.Any] = {
            "method": "GET",
            "headers": {
                "Accept": "text/event-stream",
                "Cache-Control": "no-store",
                **(headers or {}),
            },
        }
        js_response = run_sync(pyfetch(url, **fetch_options))
        body = js_response.js_response.body
        if body is not None:
            self._reader = body.getReader()

    @property
    def closed(self) -> bool:
        # True once close() ran or the remote stream ended.
        return self._closed

    def _read_chunk(self) -> str | None:
        """Read the next chunk from the ReadableStream, blocking via JSPI.

        Returns the decoded text (possibly "" when a chunk ends mid-character)
        or None on end-of-stream / any read error (treated as EOF).
        """
        if self._reader is None:
            return None
        try:
            result = run_sync(self._reader.read())
            if result.done:
                return None
            value = result.value
            if value is not None:
                # Incremental decode: a multi-byte char split across chunks is
                # held back and completed by the following chunk instead of
                # raising (which the previous code silently turned into EOF).
                return self._utf8_decoder.decode(bytes(value.to_py()))
            return None
        except Exception:  # Defensive: JS-side stream failures surface as EOF
            return None

    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Read and parse the next SSE event from the stream.

        Returns None when the stream ends.
        """
        if self._closed:
            raise OSError("The SSE extension is closed")
        # Keep reading chunks until we have a complete event (blank-line separator).
        while True:
            # Look for LF-LF first, then CRLF-CRLF as a fallback.
            sep_idx = self._buffer.find("\n\n")
            sep_len = 2
            if sep_idx == -1:
                sep_idx = self._buffer.find("\r\n\r\n")
                sep_len = 4
            if sep_idx != -1:
                raw_event = self._buffer[:sep_idx]
                # Remember the terminator actually present in the stream so
                # raw mode can reproduce it faithfully.
                separator = self._buffer[sep_idx : sep_idx + sep_len]
                self._buffer = self._buffer[sep_idx + sep_len :]
                event = self._parse_event(raw_event)
                if event is not None:
                    if raw:
                        return raw_event + separator
                    return event
                # Empty event (e.g. just comments), try next
                continue
            # Need more data
            chunk = self._read_chunk()
            if chunk is None:
                self._closed = True
                return None
            self._buffer += chunk

    def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
        """Parse a raw SSE event block into a ServerSentEvent.

        Follows the WHATWG event-stream processing model:
        - successive ``data:`` lines are concatenated with "\\n"
          (they previously overwrote each other — a spec violation),
        - ``event`` / ``id`` are last-wins,
        - an ``id`` containing U+0000 is ignored,
        - a non-integer ``retry`` is ignored,
        - comment lines (starting with ``:``) and unknown fields are skipped.
        """
        kwargs: dict[str, typing.Any] = {}
        data_lines: list[str] = []
        for line in raw_event.splitlines():
            if not line or line.startswith(":"):
                continue
            key, _, value = line.partition(":")
            if key not in {"event", "data", "retry", "id"}:
                continue
            # A single leading space after the colon is not part of the value.
            if value.startswith(" "):
                value = value[1:]
            if key == "data":
                data_lines.append(value)
                continue
            if key == "id" and "\u0000" in value:
                # Spec: ignore id fields containing NUL.
                continue
            if key == "retry":
                try:
                    kwargs[key] = int(value)
                except (ValueError, TypeError):
                    pass
                continue
            kwargs[key] = value
        if data_lines:
            kwargs["data"] = "\n".join(data_lines)
        if not kwargs:
            return None
        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id
        event = ServerSentEvent(**kwargs)
        if event.id:
            self._last_event_id = event.id
        return event

    def send_payload(self, buf: str | bytes) -> None:
        """SSE is one-way only."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    def close(self) -> None:
        """Close the stream and release resources (idempotent)."""
        if self._closed:
            return
        self._closed = True
        if self._reader is not None:
            try:
                run_sync(self._reader.cancel())
            except Exception:  # Defensive: JS-side cancel failures are ignored
                pass
            self._reader = None

View File

@@ -0,0 +1,174 @@
from __future__ import annotations
import typing
from js import Promise # type: ignore[import]
from pyodide.ffi import create_proxy, run_sync # type: ignore[import]
try:
from js import WebSocket as JSWebSocket # type: ignore[import]
except ImportError:
JSWebSocket = None
class PyodideWebSocketExtension:
    """WebSocket extension for Pyodide using the browser's native WebSocket API.

    Synchronous via JSPI (run_sync). Uses JS Promises for signaling instead of
    asyncio primitives, because run_sync bypasses the asyncio event loop.

    Messages from JS callbacks are buffered in a list and delivered via
    next_payload() which blocks via run_sync on a JS Promise.
    """

    def __init__(self, url: str) -> None:
        """Open a browser WebSocket to *url* and block (via JSPI) until connected.

        Raises:
            OSError: when the JS runtime exposes no WebSocket constructor
                (e.g. Node.js). A failed connection attempt rejects the open
                promise, which surfaces through run_sync.
        """
        if JSWebSocket is None:  # Defensive: depends on JS runtime
            raise OSError(
                "WebSocket is not available in this JavaScript runtime. "
                "Browser environment required (not supported in Node.js)."
            )
        self._closed = False
        # FIFO of messages that arrived while no consumer was waiting.
        # ``None`` is the close sentinel appended by _onclose.
        self._pending: list[str | bytes | None] = []
        # Resolve callback of the JS Promise a blocked next_payload() waits on.
        self._waiting_resolve: typing.Any = None
        # Message handed from _onmessage/_onclose to the blocked waiter.
        self._last_msg: str | bytes | None = None
        # Pyodide proxies that must be destroyed in close() to avoid leaks.
        self._proxies: list[typing.Any] = []
        # Create browser WebSocket
        self._ws = JSWebSocket.new(url)
        self._ws.binaryType = "arraybuffer"
        # Open/error signaling via a JS Promise (not asyncio.Future,
        # because run_sync/JSPI does not drive the asyncio event loop).
        _open_state: dict[str, typing.Any] = {
            "resolve": None,
            "reject": None,
        }

        def _open_executor(resolve: typing.Any, reject: typing.Any) -> None:
            # Capture the promise's resolve/reject so the onopen/onerror
            # callbacks below can settle it exactly once.
            _open_state["resolve"] = resolve
            _open_state["reject"] = reject

        exec_proxy = create_proxy(_open_executor)
        open_promise = Promise.new(exec_proxy)
        self._proxies.append(exec_proxy)

        # JS→Python callbacks: invoked by the browser event loop, not Python call
        # frames, so coverage cannot trace into them.
        def _onopen(event: typing.Any) -> None:  # Defensive: JS callback
            r = _open_state["resolve"]
            if r is not None:
                # Clear both slots so a late onerror cannot settle the promise twice.
                _open_state["resolve"] = None
                _open_state["reject"] = None
                r()

        def _onerror(event: typing.Any) -> None:  # Defensive: JS callback
            r = _open_state["reject"]
            if r is not None:
                _open_state["resolve"] = None
                _open_state["reject"] = None
                r("WebSocket connection failed")

        def _onmessage(event: typing.Any) -> None:  # Defensive: JS callback
            data = event.data
            if isinstance(data, str):
                msg: str | bytes = data
            else:
                # ArrayBuffer → bytes via Pyodide
                msg = bytes(data.to_py())
            if self._waiting_resolve is not None:
                # A consumer is blocked in next_payload(): hand the message
                # over directly and wake it up.
                self._last_msg = msg
                r = self._waiting_resolve
                self._waiting_resolve = None
                r()
            else:
                self._pending.append(msg)

        def _onclose(event: typing.Any) -> None:  # Defensive: JS callback
            # Deliver the close sentinel (None) through the same two paths
            # a regular message would take.
            if self._waiting_resolve is not None:
                self._last_msg = None
                r = self._waiting_resolve
                self._waiting_resolve = None
                r()
            else:
                self._pending.append(None)

        # Create proxies so JS can call these Python functions
        for name, fn in [
            ("onopen", _onopen),
            ("onerror", _onerror),
            ("onmessage", _onmessage),
            ("onclose", _onclose),
        ]:
            proxy = create_proxy(fn)
            self._proxies.append(proxy)
            setattr(self._ws, name, proxy)
        # Block until the connection is open (or error rejects the promise)
        run_sync(open_promise)

    @property
    def closed(self) -> bool:
        # True once close() ran or the remote end closed the connection.
        return self._closed

    def next_payload(self) -> str | bytes | None:
        """Block (via JSPI) until the next message arrives.

        Returns None when the remote end closes the connection.
        """
        if self._closed:  # Defensive: caller should not call after close
            raise OSError("The WebSocket extension is closed")
        # Drain from buffer first
        if self._pending:
            msg = self._pending.pop(0)
            if msg is None:  # Defensive: server-initiated close sentinel
                self._closed = True
            return msg

        # Wait for the next message via a JS Promise
        def _executor(resolve: typing.Any, reject: typing.Any) -> None:
            # _onmessage/_onclose will call this resolve to wake us.
            self._waiting_resolve = resolve

        exec_proxy = create_proxy(_executor)
        promise = Promise.new(exec_proxy)
        run_sync(promise)
        exec_proxy.destroy()
        msg = self._last_msg
        if msg is None:  # Defensive: server-initiated close sentinel
            self._closed = True
        return msg

    def send_payload(self, buf: str | bytes) -> None:
        """Send a message over the WebSocket."""
        if self._closed:  # Defensive: caller should not call after close
            raise OSError("The WebSocket extension is closed")
        if isinstance(buf, (bytes, bytearray)):
            from js import Uint8Array  # type: ignore[import]

            self._ws.send(Uint8Array.new(buf))
        else:
            self._ws.send(buf)

    def ping(self) -> None:
        """No-op — browser WebSocket handles ping/pong at protocol level."""
        pass

    def close(self) -> None:
        """Close the WebSocket and clean up proxies."""
        if self._closed:  # Defensive: idempotent close
            return
        self._closed = True
        try:
            self._ws.close()
        except Exception:  # Defensive: suppress JS errors on teardown
            pass
        for proxy in self._proxies:
            try:
                proxy.destroy()
            except Exception:  # Defensive: suppress JS errors on teardown
                pass
        self._proxies.clear()

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
class RevocationStrategy(Enum):
    """Which revocation mechanism (OCSP vs CRL) the check should favor.

    The exact fallback/combination semantics live in the consumer of this
    enum, which is not visible in this module.
    """

    PREFER_OCSP = 0
    PREFER_CRL = 1
    CHECK_ALL = 2
@dataclass
class RevocationConfiguration:
    """User-facing knobs for the certificate-revocation verification step."""

    # Preferred mechanism ordering; None presumably opts out of a preference —
    # TODO confirm against the consumer of this config.
    strategy: RevocationStrategy | None = RevocationStrategy.PREFER_OCSP
    # When True, infrastructure failures are handled strictly rather than
    # silently ignored (see the verify() implementations elsewhere).
    strict_mode: bool = False


# Module-level default configuration. NOTE(review): this instance is mutable
# and shared — a caller mutating it changes the default for everyone.
DEFAULT_STRATEGY: RevocationConfiguration = RevocationConfiguration()

View File

@@ -0,0 +1,374 @@
from __future__ import annotations
import contextlib
import datetime
import ipaddress
import ssl
import threading
import typing
import warnings
from random import randint
from qh3._hazmat import (
Certificate,
CertificateRevocationList,
)
from ....exceptions import RequestException, SSLError
from ....models import PreparedRequest
from ....packages.urllib3 import ConnectionInfo
from ....packages.urllib3.contrib.resolver import BaseResolver
from ....packages.urllib3.exceptions import SecurityWarning
from ....typing import ProxyType
from .._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
class InMemoryRevocationList:
    """Thread-safe, size-bounded in-process cache of downloaded CRLs.

    ``_store`` maps CRL distribution-point URLs to parsed CRLs; per-leaf
    metadata (issuer certificate, last used CRL endpoint) is keyed by the
    leaf's SHA-256 fingerprint. Pickle support goes through
    __getstate__/__setstate__ (locks are recreated on restore).
    """

    def __init__(self, max_size: int = 256):
        # Maximum number of cached CRLs before save() evicts one entry.
        self._max_size: int = max_size
        # crl_distribution_point URL -> parsed CRL
        self._store: dict[str, CertificateRevocationList] = {}
        # leaf fingerprint -> issuer certificate
        self._issuers_map: dict[str, Certificate] = {}
        # leaf fingerprint -> CRL endpoint previously used for that leaf
        self._crl_endpoints: dict[str, str] = {}
        # Consecutive failures; a successful save() resets it to zero.
        self._failure_count: int = 0
        # Guards every mapping above; reentrant so nested calls are safe.
        self._access_lock = threading.RLock()
        # One lock per leaf fingerprint: threads validating the same
        # certificate serialize, distinct certificates proceed in parallel.
        self._second_level_locks: dict[str, threading.RLock] = {}

    @contextlib.contextmanager
    def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
        """Serialize revocation work on a single peer certificate."""
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        # Only the lock *lookup/creation* happens under the global lock; the
        # per-certificate acquire happens outside so one slow certificate
        # check cannot stall every other thread.
        with self._access_lock:
            if fingerprint not in self._second_level_locks:
                self._second_level_locks[fingerprint] = threading.RLock()
            lock = self._second_level_locks[fingerprint]
        lock.acquire()
        try:
            yield
        finally:
            lock.release()

    def __getstate__(self) -> dict[str, typing.Any]:
        # Locks are not picklable: serialize only the data, under the lock.
        with self._access_lock:
            return {
                "_max_size": self._max_size,
                "_store": {k: v.serialize() for k, v in self._store.items()},
                "_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
                "_failure_count": self._failure_count,
                "_crl_endpoints": self._crl_endpoints,
            }

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        """Restore from __getstate__ output; locks are recreated fresh.

        Raises:
            OSError: when a mandatory key is missing from *state*.
        """
        if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
            raise OSError("unrecoverable state for InMemoryRevocationStatus")
        self._access_lock = threading.RLock()
        self._second_level_locks = {}
        self._max_size = state["_max_size"]
        # "_failure_count" may be absent in states from older versions.
        self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
        self._crl_endpoints = state["_crl_endpoints"]
        self._store = {}
        for k, v in state["_store"].items():
            self._store[k] = CertificateRevocationList.deserialize(v)
        self._issuers_map = {}
        for k, v in state["_issuers_map"].items():
            self._issuers_map[k] = Certificate.deserialize(v)

    def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
        """Return the cached issuer certificate for *peer_certificate*, if known."""
        with self._access_lock:
            fingerprint: str = _str_fingerprint_of(peer_certificate)
            if fingerprint not in self._issuers_map:
                return None
            return self._issuers_map[fingerprint]

    def __len__(self) -> int:
        # Number of cached CRLs.
        with self._access_lock:
            return len(self._store)

    def incr_failure(self) -> None:
        # Bump the consecutive-failure counter (reset by a successful save()).
        with self._access_lock:
            self._failure_count += 1

    @property
    def failure_count(self) -> int:
        # NOTE(review): read without _access_lock — benign for an int under the
        # GIL, but inconsistent with the other accessors.
        return self._failure_count

    def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
        """Return the cached CRL for the endpoint, dropping it first if stale."""
        with self._access_lock:
            if crl_distribution_point not in self._store:
                return None
            cached_response = self._store[crl_distribution_point]
            # Past its own nextUpdate timestamp: evict and report a miss.
            if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
                del self._store[crl_distribution_point]
                return None
            return cached_response

    def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
        """Endpoint previously used for *leaf*, so we keep hitting the same CRL."""
        fingerprint = _str_fingerprint_of(leaf)
        if fingerprint in self._crl_endpoints:
            return self._crl_endpoints[fingerprint]
        return None

    def save(
        self,
        leaf: Certificate,
        issuer: Certificate,
        crl: CertificateRevocationList,
        crl_distribution_point: str,
    ) -> None:
        """Cache *crl* plus the leaf→issuer/endpoint mappings; reset failures.

        When the store is full, one entry is evicted first.
        """
        with self._access_lock:
            if len(self._store) >= self._max_size:
                tbd_key: str | None = None
                closest_next_update: int | None = None
                # NOTE(review): despite the variable name, the ">" comparison
                # selects the entry with the *latest* next_update_at for
                # eviction. Confirm whether soonest-to-expire ("closest")
                # eviction was the actual intent before changing it.
                for k in self._store:
                    if closest_next_update is None:
                        closest_next_update = self._store[k].next_update_at
                        tbd_key = k
                        continue
                    if self._store[k].next_update_at > closest_next_update:  # type: ignore
                        closest_next_update = self._store[k].next_update_at
                        tbd_key = k
                if tbd_key:
                    del self._store[tbd_key]
                else:
                    # No usable next_update_at anywhere: drop the first insertion.
                    first_key = list(self._store.keys())[0]
                    del self._store[first_key]
            peer_fingerprint: str = _str_fingerprint_of(leaf)
            self._store[crl_distribution_point] = crl
            self._crl_endpoints[peer_fingerprint] = crl_distribution_point
            self._issuers_map[peer_fingerprint] = issuer
            self._failure_count = 0
def verify(
    r: PreparedRequest,
    strict: bool = False,
    timeout: float | int = 0.2,
    proxies: ProxyType | None = None,
    resolver: BaseResolver | None = None,
    happy_eyeballs: bool | int = False,
    cache: InMemoryRevocationList | None = None,
) -> None:
    """Check the peer certificate of *r* against one of its CRL endpoints.

    Side effects: sets ``r.ocsp_verified`` when a verdict is reached and may
    fill ``conn_info.issuer_certificate_der``. Raises ``SSLError`` when the
    certificate is revoked, or when signature verification is enabled and the
    CRL fails authentication. In non-strict mode most infrastructure failures
    are silently ignored; strict mode emits ``SecurityWarning`` instead.

    :param timeout: per-request timeout for the CRL download.
    :param cache: shared CRL cache; a throwaway one is created when omitted.
    """
    conn_info: ConnectionInfo | None = r.conn_info
    # we can't do anything in that case.
    if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
        return
    endpoints: list[str] = [  # type: ignore
        # exclude non-HTTP endpoint. like ldap.
        ep  # type: ignore
        for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", []))  # type: ignore
        if ep.startswith("http://")  # type: ignore
    ]
    # no CRL distribution point available.
    if not endpoints:
        return
    if cache is None:
        cache = InMemoryRevocationList()
    if not strict:
        # Back off after repeated failures instead of hammering broken CRL infra.
        if cache.failure_count >= 4:
            return
    # some corporate environment
    # have invalid OCSP implementation
    # they use a cert that IS NOT in the chain
    # to sign the response. It's weird but true.
    # see https://github.com/jawah/niquests/issues/274
    ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
    verify_signature = strict is True or ignore_signature_without_strict is False
    peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
    # Reuse the endpoint used last time for this leaf; otherwise pick one at random.
    crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
    with cache.lock_for(peer_certificate):
        cached_revocation_list = cache.check(crl_distribution_point)
        if cached_revocation_list is not None:
            issuer_certificate = cache.get_issuer_of(peer_certificate)
            if issuer_certificate:
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
            if revocation_status is not None:
                r.ocsp_verified = False
                # NOTE(review): the trailing comma makes this a 2-tuple passed
                # to SSLError; the fresh-CRL path below passes a single string.
                # Likely unintended — confirm before changing.
                raise SSLError(
                    (
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. ",
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information.",
                    )
                )
            else:
                r.ocsp_verified = True
            return
        from ....sessions import Session

        with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
            session.trust_env = False
            if proxies:
                session.proxies = proxies
            # When using Python native capabilities, you won't have the issuerCA DER by default.
            # Unfortunately! But no worries, we can circumvent it!
            # Three ways are valid to fetch it (in order of preference, safest to riskiest):
            # - The issuer can be (but unlikely) a root CA.
            # - Retrieve it by asking it from the TLS layer.
            # - Downloading it using specified caIssuers from the peer certificate.
            if conn_info.issuer_certificate_der is None:
                # It could be a root (self-signed) certificate. Or a previously seen issuer.
                issuer_certificate = cache.get_issuer_of(peer_certificate)
                hint_ca_issuers: list[str] = [
                    ep  # type: ignore
                    for ep in list(conn_info.certificate_dict.get("caIssuers", []))  # type: ignore
                    if ep.startswith("http://")  # type: ignore
                ]
                if issuer_certificate is None and hint_ca_issuers:
                    try:
                        raw_intermediary_response = session.get(hint_ca_issuers[0])
                    except RequestException:
                        # Best-effort AIA fetch; failure handled by the None check below.
                        pass
                    else:
                        if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
                            raw_intermediary_content = raw_intermediary_response.content
                            if raw_intermediary_content is not None:
                                # binary DER
                                if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
                                    issuer_certificate = Certificate(raw_intermediary_content)
                                # b64 PEM
                                elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
                                    issuer_certificate = Certificate(
                                        ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
                                    )
                # Well! We're out of luck. No further should we go.
                if issuer_certificate is None:
                    cache.incr_failure()
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                "via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
                                "Revocation check. Reason: Remote did not provide any intermediary certificate."
                            ),
                            SecurityWarning,
                        )
                    return
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            else:
                issuer_certificate = Certificate(conn_info.issuer_certificate_der)
            try:
                crl_http_response = session.get(
                    crl_distribution_point,
                    timeout=timeout,
                )
            except RequestException as e:
                # NOTE(review): unlike the async twin of this module, the
                # failure counter is NOT incremented here — confirm intent.
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"Reason: {e}"
                        ),
                        SecurityWarning,
                    )
                return
            if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
                if crl_http_response.content is None:
                    return
                try:
                    crl = CertificateRevocationList(crl_http_response.content)
                except ValueError:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                "Reason: The X509 CRL parser failed to read the response"
                            ),
                            SecurityWarning,
                        )
                    return
                if verify_signature:
                    # Verify the signature of the OCSP response with issuer public key
                    try:
                        if not crl.authenticate_for(issuer_certificate.public_bytes()):
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} "
                                "because the CRL response received has been tampered. "
                                "You could be targeted by a MITM attack."
                            )
                    except ValueError:
                        # Unsupported signature algorithm: warn (strict) but continue.
                        if strict:
                            warnings.warn(
                                (
                                    f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                    "via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
                                    "Revocation check. Reason: The X509 CRL is signed using an unsupported algorithm."
                                ),
                                SecurityWarning,
                            )
                cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
                revocation_status = crl.is_revoked(peer_certificate.serial_number)
                if revocation_status is not None:
                    r.ocsp_verified = False
                    raise SSLError(
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. "
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information."
                    )
                else:
                    r.ocsp_verified = True
            else:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"CRL endpoint: {str(crl_http_response)}"
                        ),
                        SecurityWarning,
                    )


__all__ = ("verify",)

View File

@@ -0,0 +1,405 @@
from __future__ import annotations
import asyncio
import datetime
import ipaddress
import ssl
import typing
import warnings
from contextlib import asynccontextmanager
from random import randint
from qh3._hazmat import (
Certificate,
CertificateRevocationList,
)
from .....exceptions import RequestException, SSLError
from .....models import PreparedRequest
from .....packages.urllib3 import ConnectionInfo
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
from .....packages.urllib3.exceptions import SecurityWarning
from .....typing import ProxyType
from .....utils import is_cancelled_error_root_cause
from ..._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
class InMemoryRevocationList:
    """Asyncio flavor of the in-process CRL cache.

    Same data model as the thread-safe version, but coordination uses one
    ``asyncio.Semaphore`` per leaf fingerprint (see ``lock``) instead of
    RLocks; plain dict accesses are not guarded, which is fine within a
    single event loop.
    """

    def __init__(self, max_size: int = 256):
        # Maximum number of cached CRLs before save() evicts one entry.
        self._max_size: int = max_size
        # crl_distribution_point URL -> parsed CRL
        self._store: dict[str, CertificateRevocationList] = {}
        # leaf fingerprint -> per-certificate semaphore (see lock()).
        self._semaphores: dict[str, asyncio.Semaphore] = {}
        # leaf fingerprint -> issuer certificate
        self._issuers_map: dict[str, Certificate] = {}
        # leaf fingerprint -> CRL endpoint previously used for that leaf
        self._crl_endpoints: dict[str, str] = {}
        # Consecutive failures; a successful save() resets it to zero.
        self._failure_count: int = 0

    def __getstate__(self) -> dict[str, typing.Any]:
        # Semaphores are not picklable: serialize only the data.
        return {
            "_max_size": self._max_size,
            "_store": {k: v.serialize() for k, v in self._store.items()},
            "_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
            "_failure_count": self._failure_count,
            "_crl_endpoints": self._crl_endpoints,
        }

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        """Restore from __getstate__ output; semaphores are recreated fresh.

        Raises:
            OSError: when a mandatory key is missing from *state*.
        """
        if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
            raise OSError("unrecoverable state for InMemoryRevocationStatus")
        self._max_size = state["_max_size"]
        # "_failure_count" may be absent in states from older versions.
        self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
        self._crl_endpoints = state["_crl_endpoints"]
        self._semaphores = {}
        self._store = {}
        for k, v in state["_store"].items():
            self._store[k] = CertificateRevocationList.deserialize(v)
        self._issuers_map = {}
        for k, v in state["_issuers_map"].items():
            self._issuers_map[k] = Certificate.deserialize(v)

    @asynccontextmanager
    async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
        """Serialize revocation work on a single peer certificate.

        Uses a per-fingerprint Semaphore (default counter 1) so concurrent
        tasks checking the same certificate queue up.
        """
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        if fingerprint not in self._semaphores:
            self._semaphores[fingerprint] = asyncio.Semaphore()
        await self._semaphores[fingerprint].acquire()
        try:
            yield
        finally:
            self._semaphores[fingerprint].release()

    def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
        """Return the cached issuer certificate for *peer_certificate*, if known."""
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        if fingerprint not in self._issuers_map:
            return None
        return self._issuers_map[fingerprint]

    def __len__(self) -> int:
        # Number of cached CRLs.
        return len(self._store)

    def incr_failure(self) -> None:
        # Bump the consecutive-failure counter (reset by a successful save()).
        self._failure_count += 1

    @property
    def failure_count(self) -> int:
        # Consecutive failures; consumers use it to back off.
        return self._failure_count

    def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
        """Return the cached CRL for the endpoint, dropping it first if stale."""
        if crl_distribution_point not in self._store:
            return None
        cached_response = self._store[crl_distribution_point]
        # Past its own nextUpdate timestamp: evict and report a miss.
        if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
            del self._store[crl_distribution_point]
            return None
        return cached_response

    def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
        """Endpoint previously used for *leaf*, so we keep hitting the same CRL."""
        fingerprint = _str_fingerprint_of(leaf)
        if fingerprint in self._crl_endpoints:
            return self._crl_endpoints[fingerprint]
        return None

    def save(
        self,
        leaf: Certificate,
        issuer: Certificate,
        crl: CertificateRevocationList,
        crl_distribution_point: str,
    ) -> None:
        """Cache *crl* plus the leaf→issuer/endpoint mappings; reset failures.

        When the store is full, one entry is evicted first.
        """
        if len(self._store) >= self._max_size:
            tbd_key: str | None = None
            closest_next_update: int | None = None
            # NOTE(review): despite the variable name, the ">" comparison
            # selects the entry with the *latest* next_update_at for eviction.
            # Confirm whether soonest-to-expire eviction was intended.
            for k in self._store:
                if closest_next_update is None:
                    closest_next_update = self._store[k].next_update_at
                    tbd_key = k
                    continue
                if self._store[k].next_update_at > closest_next_update:  # type: ignore
                    closest_next_update = self._store[k].next_update_at
                    tbd_key = k
            if tbd_key:
                del self._store[tbd_key]
            else:
                # No usable next_update_at anywhere: drop the first insertion.
                first_key = list(self._store.keys())[0]
                del self._store[first_key]
        peer_fingerprint: str = _str_fingerprint_of(leaf)
        self._store[crl_distribution_point] = crl
        self._crl_endpoints[peer_fingerprint] = crl_distribution_point
        self._issuers_map[peer_fingerprint] = issuer
        self._failure_count = 0
async def verify(
    r: PreparedRequest,
    strict: bool = False,
    timeout: float | int = 0.2,
    proxies: ProxyType | None = None,
    resolver: AsyncBaseResolver | None = None,
    happy_eyeballs: bool | int = False,
    cache: InMemoryRevocationList | None = None,
) -> None:
    """Async twin of the sync CRL ``verify``.

    Checks the peer certificate of *r* against one of its CRL distribution
    points; sets ``r.ocsp_verified`` when a verdict is reached and may fill
    ``conn_info.issuer_certificate_der``. Raises ``SSLError`` on revocation
    (or tampered CRL when signature verification is enabled). Task
    cancellation (asyncio.CancelledError, directly or as the root cause of a
    RequestException) aborts silently without warnings or errors.
    """
    conn_info: ConnectionInfo | None = r.conn_info
    # we can't do anything in that case.
    if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
        return
    endpoints: list[str] = [  # type: ignore
        # exclude non-HTTP endpoint. like ldap.
        ep  # type: ignore
        for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", []))  # type: ignore
        if ep.startswith("http://")  # type: ignore
    ]
    # well... not all issued certificate have a OCSP entry. e.g. mkcert.
    if not endpoints:
        return
    if cache is None:
        cache = InMemoryRevocationList()
    if not strict:
        # Back off after repeated failures instead of hammering broken CRL infra.
        if cache.failure_count >= 4:
            return
    # some corporate environment
    # have invalid OCSP implementation
    # they use a cert that IS NOT in the chain
    # to sign the response. It's weird but true.
    # see https://github.com/jawah/niquests/issues/274
    ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
    verify_signature = strict is True or ignore_signature_without_strict is False
    peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
    # Reuse the endpoint used last time for this leaf; otherwise pick one at random.
    crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
    # Fast path: consult the cache before paying for the per-certificate lock.
    cached_revocation_list = cache.check(crl_distribution_point)
    if cached_revocation_list is not None:
        issuer_certificate = cache.get_issuer_of(peer_certificate)
        if issuer_certificate is not None:
            conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
        revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
        if revocation_status is not None:
            r.ocsp_verified = False
            # NOTE(review): the trailing comma makes this a 2-tuple passed to
            # SSLError; the fresh-CRL path below passes a single string.
            # Likely unintended — confirm before changing.
            raise SSLError(
                (
                    f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                    f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
                    "You should avoid trying to request anything from it as the remote has been compromised. ",
                    "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                    "for more information.",
                )
            )
        else:
            r.ocsp_verified = True
        return
    async with cache.lock(peer_certificate):
        # why are we doing this twice?
        # because using Semaphore to prevent concurrent
        # revocation check have a heavy toll!
        # todo: respect DRY
        cached_revocation_list = cache.check(crl_distribution_point)
        if cached_revocation_list is not None:
            issuer_certificate = cache.get_issuer_of(peer_certificate)
            if issuer_certificate is not None:
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
            if revocation_status is not None:
                r.ocsp_verified = False
                # NOTE(review): 2-tuple passed to SSLError here as well.
                raise SSLError(
                    (
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. ",
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information.",
                    )
                )
            else:
                r.ocsp_verified = True
            return
        from .....async_session import AsyncSession

        async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
            session.trust_env = False
            if proxies:
                session.proxies = proxies
            # When using Python native capabilities, you won't have the issuerCA DER by default.
            # Unfortunately! But no worries, we can circumvent it!
            # Three ways are valid to fetch it (in order of preference, safest to riskiest):
            # - The issuer can be (but unlikely) a root CA.
            # - Retrieve it by asking it from the TLS layer.
            # - Downloading it using specified caIssuers from the peer certificate.
            if conn_info.issuer_certificate_der is None:
                # It could be a root (self-signed) certificate. Or a previously seen issuer.
                issuer_certificate = cache.get_issuer_of(peer_certificate)
                hint_ca_issuers: list[str] = [
                    ep  # type: ignore
                    for ep in list(conn_info.certificate_dict.get("caIssuers", []))  # type: ignore
                    if ep.startswith("http://")  # type: ignore
                ]
                if issuer_certificate is None and hint_ca_issuers:
                    try:
                        raw_intermediary_response = await session.get(hint_ca_issuers[0])
                    except RequestException as e:
                        # A cancellation disguised as a request failure: bail silently.
                        if is_cancelled_error_root_cause(e):
                            return
                    except asyncio.CancelledError:  # don't raise any error or warnings!
                        return
                    else:
                        if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
                            raw_intermediary_content = raw_intermediary_response.content
                            if raw_intermediary_content is not None:
                                # binary DER
                                if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
                                    issuer_certificate = Certificate(raw_intermediary_content)
                                # b64 PEM
                                elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
                                    issuer_certificate = Certificate(
                                        ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
                                    )
                # Well! We're out of luck. No further should we go.
                if issuer_certificate is None:
                    # aia fetching should be counted as general ocsp failure too.
                    cache.incr_failure()
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                "via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
                                "Revocation check. Reason: Remote did not provide any intermediary certificate."
                            ),
                            SecurityWarning,
                        )
                    return
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            else:
                issuer_certificate = Certificate(conn_info.issuer_certificate_der)
            try:
                crl_http_response = await session.get(
                    crl_distribution_point,
                    timeout=timeout,
                )
            except RequestException as e:
                # A cancellation disguised as a request failure: bail silently.
                if is_cancelled_error_root_cause(e):
                    return
                # aia fetching should be counted as general ocsp failure too.
                cache.incr_failure()
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"Reason: {e}"
                        ),
                        SecurityWarning,
                    )
                return
            except asyncio.CancelledError:  # don't raise any error or warnings!
                return
            if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
                if crl_http_response.content is None:
                    return
                try:
                    crl = CertificateRevocationList(crl_http_response.content)
                except ValueError:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                "Reason: The X509 CRL parser failed to read the response"
                            ),
                            SecurityWarning,
                        )
                    return
                if verify_signature:
                    # Verify the signature of the OCSP response with issuer public key
                    try:
                        if not crl.authenticate_for(issuer_certificate.public_bytes()):
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} "
                                "because the CRL response received has been tampered. "
                                "You could be targeted by a MITM attack."
                            )
                    except ValueError:
                        # Unsupported signature algorithm: warn (strict) but continue.
                        if strict:
                            warnings.warn(
                                (
                                    f"Unable to insure that the remote peer ({r.url}) has a currently valid "
                                    "certificate via CRL. You are seeing this warning due to enabling strict "
                                    "mode for OCSP / Revocation check. "
                                    "Reason: The X509 CRL is signed using an unsupported algorithm."
                                ),
                                SecurityWarning,
                            )
                cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
                revocation_status = crl.is_revoked(peer_certificate.serial_number)
                if revocation_status is not None:
                    r.ocsp_verified = False
                    raise SSLError(
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. "
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information."
                    )
                else:
                    r.ocsp_verified = True
            else:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"CRL endpoint: {str(crl_http_response)}"
                        ),
                        SecurityWarning,
                    )


__all__ = ("verify",)

View File

@@ -0,0 +1,477 @@
from __future__ import annotations
import contextlib
import datetime
import ipaddress
import ssl
import threading
import typing
import warnings
from functools import lru_cache
from hashlib import sha256
from random import randint
from qh3._hazmat import (
Certificate,
OCSPCertStatus,
OCSPRequest,
OCSPResponse,
OCSPResponseStatus,
ReasonFlags,
)
from ....exceptions import RequestException, SSLError
from ....models import PreparedRequest
from ....packages.urllib3 import ConnectionInfo
from ....packages.urllib3.contrib.resolver import BaseResolver
from ....packages.urllib3.exceptions import SecurityWarning
from ....typing import ProxyType
@lru_cache(maxsize=64)
def _parse_x509_der_cached(der: bytes) -> Certificate:
    """Build a Certificate from its DER bytes, memoizing the 64 most recent parses."""
    certificate = Certificate(der)
    return certificate
@lru_cache(maxsize=64)
def _fingerprint_raw_data(payload: bytes) -> str:
return "".join([format(i, "02x") for i in sha256(payload).digest()])
def _str_fingerprint_of(certificate: Certificate) -> str:
    """Return the SHA-256 hex fingerprint of *certificate*'s DER encoding."""
    der_bytes = certificate.public_bytes()
    return _fingerprint_raw_data(der_bytes)
def readable_revocation_reason(flag: ReasonFlags | None) -> str | None:
    """Turn a ReasonFlags member into a lowercase human-readable word; None passes through."""
    if flag is None:
        return None
    # str(flag) looks like "ReasonFlags.keyCompromise"; keep only the member name.
    return str(flag).split(".")[-1].lower()
class InMemoryRevocationStatus:
    """Bounded, thread-safe in-memory cache of OCSP responses.

    Entries are keyed by the SHA-256 fingerprint of the peer certificate. For
    every peer we keep the latest OCSP response together with the issuer
    certificate that validated it, plus bookkeeping used by ``verify``'s soft
    rate limiter (check timings, consecutive failure count, hold flag).
    """

    def __init__(self, max_size: int = 2048):
        # Hard cap on stored responses; also bounds the timing sample list.
        self._max_size: int = max_size
        # peer fingerprint (sha256 hex) -> most recent OCSP response.
        self._store: dict[str, OCSPResponse] = {}
        # peer fingerprint -> issuer certificate matching the stored response.
        self._issuers_map: dict[str, Certificate] = {}
        # Timestamps of successful save() calls, consumed by rate().
        self._timings: list[datetime.datetime] = []
        # Consecutive OCSP failures; reset to zero by a successful save().
        self._failure_count: int = 0
        # Coarse lock protecting every attribute of this cache.
        self._access_lock = threading.RLock()
        # Per-certificate locks created lazily by lock_for().
        self._second_level_locks: dict[str, threading.RLock] = {}
        # When True, verify() short-circuits (soft rate limiter tripped).
        self.hold: bool = False

    @staticmethod
    def support_pickle() -> bool:
        """This gives you a hint on whether you can cache it to restore later."""
        return hasattr(OCSPResponse, "serialize")

    def __getstate__(self) -> dict[str, typing.Any]:
        # Locks, timings and the hold flag are runtime-only state and are
        # deliberately rebuilt from scratch in __setstate__.
        with self._access_lock:
            return {
                "_max_size": self._max_size,
                "_store": {k: v.serialize() for k, v in self._store.items()},
                "_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
                "_failure_count": self._failure_count,
            }

    @contextlib.contextmanager
    def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
        """Serialize revocation checks that target the same peer certificate.

        Fix: the global ``_access_lock`` is now held only long enough to fetch
        (or lazily create) the per-certificate lock. Previously the
        per-certificate lock was acquired and the caller's body executed while
        still holding the global lock, which serialized every revocation check
        (network I/O included) behind one lock and could deadlock against cache
        methods invoked from inside the guarded section by other threads.
        """
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        with self._access_lock:
            if fingerprint not in self._second_level_locks:
                self._second_level_locks[fingerprint] = threading.RLock()
            lock = self._second_level_locks[fingerprint]
        # Acquire outside the global lock: only checks for the *same*
        # certificate wait on each other.
        lock.acquire()
        try:
            yield
        finally:
            lock.release()

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
            raise OSError("unrecoverable state for InMemoryRevocationStatus")
        # Runtime-only members are recreated fresh.
        self._access_lock = threading.RLock()
        self._second_level_locks = {}
        self.hold = False
        self._timings = []
        self._max_size = state["_max_size"]
        self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
        self._store = {}
        for k, v in state["_store"].items():
            self._store[k] = OCSPResponse.deserialize(v)
        self._issuers_map = {}
        for k, v in state["_issuers_map"].items():
            self._issuers_map[k] = Certificate.deserialize(v)

    def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
        """Return the cached issuer certificate for *peer_certificate*, if any."""
        with self._access_lock:
            fingerprint: str = _str_fingerprint_of(peer_certificate)
            if fingerprint not in self._issuers_map:
                return None
            return self._issuers_map[fingerprint]

    def __len__(self) -> int:
        with self._access_lock:
            return len(self._store)

    def incr_failure(self) -> None:
        """Record one more consecutive OCSP responder failure."""
        with self._access_lock:
            self._failure_count += 1

    @property
    def failure_count(self) -> int:
        return self._failure_count

    def rate(self) -> float:
        """Mean delay in seconds between consecutive successful checks (0.0 when unknown)."""
        with self._access_lock:
            previous_dt: datetime.datetime | None = None
            delays: list[float] = []
            for dt in self._timings:
                if previous_dt is None:
                    previous_dt = dt
                    continue
                delays.append((dt - previous_dt).total_seconds())
                previous_dt = dt
            return sum(delays) / len(delays) if delays else 0.0

    def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
        """Return the cached OCSP response for *peer_certificate*, evicting expired GOOD entries."""
        with self._access_lock:
            fingerprint: str = _str_fingerprint_of(peer_certificate)
            if fingerprint not in self._store:
                return None
            cached_response = self._store[fingerprint]
            if cached_response.certificate_status == OCSPCertStatus.GOOD:
                # A GOOD verdict is only trusted until its next_update deadline.
                if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
                    del self._store[fingerprint]
                    return None
                return cached_response
            # Non-GOOD verdicts (revoked/unknown) are kept until evicted by size pressure.
            return cached_response

    def save(
        self,
        peer_certificate: Certificate,
        issuer_certificate: Certificate,
        ocsp_response: OCSPResponse,
    ) -> None:
        """Store *ocsp_response* for *peer_certificate*, evicting one entry when full."""
        with self._access_lock:
            if len(self._store) >= self._max_size:
                # Eviction preference: an unsuccessful responder reply first,
                # otherwise the non-revoked entry with the latest next_update,
                # otherwise (everything revoked) the oldest inserted key.
                tbd_key: str | None = None
                closest_next_update: int | None = None
                for k in self._store:
                    if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
                        tbd_key = k
                        break
                    if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
                        if closest_next_update is None:
                            closest_next_update = self._store[k].next_update
                            tbd_key = k
                            continue
                        if self._store[k].next_update > closest_next_update:  # type: ignore
                            closest_next_update = self._store[k].next_update
                            tbd_key = k
                if tbd_key:
                    del self._store[tbd_key]
                    del self._issuers_map[tbd_key]
                else:
                    first_key = list(self._store.keys())[0]
                    del self._store[first_key]
                    del self._issuers_map[first_key]
            peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
            self._store[peer_fingerprint] = ocsp_response
            self._issuers_map[peer_fingerprint] = issuer_certificate
            self._failure_count = 0
            self._timings.append(datetime.datetime.now())
            if len(self._timings) >= self._max_size:
                self._timings.pop(0)
def verify(
    r: PreparedRequest,
    strict: bool = False,
    timeout: float | int = 0.2,
    proxies: ProxyType | None = None,
    resolver: BaseResolver | None = None,
    happy_eyeballs: bool | int = False,
    cache: InMemoryRevocationStatus | None = None,
) -> None:
    """Verify via OCSP the revocation status of the certificate served for *r*.

    Sets ``r.ocsp_verified`` and raises :class:`SSLError` when the certificate
    is revoked, or (in strict mode only) when its status is unknown or the
    responder's signature cannot be trusted. In non-strict mode transient
    failures never break the request: they degrade to warnings or silent
    returns.

    :param r: Prepared request whose ``conn_info`` carries the peer certificate.
    :param strict: Escalate soft failures to warnings and unknown status to errors.
    :param timeout: Budget (seconds) for the HTTP exchange with the responder.
    :param proxies: Optional proxy map applied to the internal session.
    :param resolver: Optional resolver handed to the internal session.
    :param happy_eyeballs: Forwarded to the internal session.
    :param cache: Shared response cache; a throwaway one is created when omitted.
    """
    conn_info: ConnectionInfo | None = r.conn_info

    # we can't do anything in that case.
    if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
        return

    endpoints: list[str] = [  # type: ignore
        # exclude non-HTTP endpoint. like ldap.
        ep  # type: ignore
        for ep in list(conn_info.certificate_dict.get("OCSP", []))  # type: ignore
        if ep.startswith("http://")  # type: ignore
    ]

    # well... not all issued certificate have a OCSP entry. e.g. mkcert.
    if not endpoints:
        return

    if cache is None:
        cache = InMemoryRevocationStatus()

    # this feature, by default, is reserved for a reasonable usage.
    if not strict:
        if cache.failure_count >= 4:
            return

        mean_rate_sec = cache.rate()
        cache_count = len(cache)

        # Sustained, rapid traffic trips the hold flag and stops further checks.
        if cache_count >= 10 and mean_rate_sec <= 1.0:
            cache.hold = True

        if cache.hold:
            return

    # some corporate environment
    # have invalid OCSP implementation
    # they use a cert that IS NOT in the chain
    # to sign the response. It's weird but true.
    # see https://github.com/jawah/niquests/issues/274
    ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
    verify_signature = strict is True or ignore_signature_without_strict is False

    peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)

    with cache.lock_for(peer_certificate):
        cached_response = cache.check(peer_certificate)

        if cached_response is not None:
            issuer_certificate = cache.get_issuer_of(peer_certificate)

            if issuer_certificate:
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()

            if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
                if cached_response.certificate_status == OCSPCertStatus.REVOKED:
                    r.ocsp_verified = False
                    # Fixed: the message parts were previously wrapped in a tuple,
                    # so the SSLError rendered as a tuple repr instead of text.
                    raise SSLError(
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. "
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information."
                    )
                elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
                    r.ocsp_verified = False
                    if strict is True:
                        raise SSLError(
                            f"Unable to establish a secure connection to {r.url} because the issuer does not know "
                            "whether certificate is valid or not. This error occurred because you enabled strict mode "
                            "for the OCSP / Revocation check."
                        )
                else:
                    r.ocsp_verified = True

            return

        from ....sessions import Session

        with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
            session.trust_env = False

            if proxies:
                session.proxies = proxies

            # When using Python native capabilities, you won't have the issuerCA DER by default.
            # Unfortunately! But no worries, we can circumvent it!
            # Three ways are valid to fetch it (in order of preference, safest to riskiest):
            #   - The issuer can be (but unlikely) a root CA.
            #   - Retrieve it by asking it from the TLS layer.
            #   - Downloading it using specified caIssuers from the peer certificate.
            if conn_info.issuer_certificate_der is None:
                # It could be a root (self-signed) certificate. Or a previously seen issuer.
                issuer_certificate = cache.get_issuer_of(peer_certificate)

                hint_ca_issuers: list[str] = [
                    ep  # type: ignore
                    for ep in list(conn_info.certificate_dict.get("caIssuers", []))  # type: ignore
                    if ep.startswith("http://")  # type: ignore
                ]

                # try to do AIA fetching of intermediate certificate (issuer)
                if issuer_certificate is None and hint_ca_issuers:
                    try:
                        raw_intermediary_response = session.get(hint_ca_issuers[0])
                    except RequestException:
                        pass
                    else:
                        if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
                            raw_intermediary_content = raw_intermediary_response.content

                            if raw_intermediary_content is not None:
                                # binary DER
                                if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
                                    issuer_certificate = Certificate(raw_intermediary_content)
                                # b64 PEM
                                elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
                                    issuer_certificate = Certificate(
                                        ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
                                    )

                # Well! We're out of luck. No further should we go.
                if issuer_certificate is None:
                    # aia fetching should be counted as general ocsp failure too.
                    cache.incr_failure()

                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                "via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
                                "Revocation check. Reason: Remote did not provide any intermediary certificate."
                            ),
                            SecurityWarning,
                        )
                    return

                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            else:
                issuer_certificate = Certificate(conn_info.issuer_certificate_der)

            try:
                req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
            except ValueError:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            "Reason: The X509 OCSP generator failed to assemble the request."
                        ),
                        SecurityWarning,
                    )
                return

            try:
                # Pick one responder at random when the certificate lists several.
                ocsp_http_response = session.post(
                    endpoints[randint(0, len(endpoints) - 1)],
                    data=req.public_bytes(),
                    headers={"Content-Type": "application/ocsp-request"},
                    timeout=timeout,
                )
            except RequestException as e:
                # we want to monitor failures related to the responder.
                # we don't want to ruin the http experience in normal circumstances.
                cache.incr_failure()

                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"Reason: {e}"
                        ),
                        SecurityWarning,
                    )
                return

            if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
                if ocsp_http_response.content is None:
                    return

                try:
                    ocsp_resp = OCSPResponse(ocsp_http_response.content)
                except ValueError:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                "Reason: The X509 OCSP parser failed to read the response"
                            ),
                            SecurityWarning,
                        )
                    return

                # Verify the signature of the OCSP response with issuer public key
                if verify_signature:
                    try:
                        if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()):  # type: ignore[attr-defined]
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} "
                                "because the OCSP response received has been tampered. "
                                "You could be targeted by a MITM attack."
                            )
                    except ValueError:  # Defensive: unsupported signature case
                        if strict:
                            warnings.warn(
                                (
                                    f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                    "via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
                                    "Revocation check. Reason: The X509 OCSP response is signed using an unsupported algorithm."
                                ),
                                SecurityWarning,
                            )

                cache.save(peer_certificate, issuer_certificate, ocsp_resp)

                if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
                    if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
                        r.ocsp_verified = False
                        raise SSLError(
                            f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                            f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
                            "You should avoid trying to request anything from it as the remote has been compromised. "
                            "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                            "for more information."
                        )

                    if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
                        r.ocsp_verified = False
                        if strict is True:
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
                                "certificate is valid or not. This error occurred because you enabled strict mode for "
                                "the OCSP / Revocation check."
                            )
                    else:
                        r.ocsp_verified = True
                else:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                f"OCSP Server Status: {ocsp_resp.response_status}"
                            ),
                            SecurityWarning,
                        )
            else:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"OCSP Server Status: {str(ocsp_http_response)}"
                        ),
                        SecurityWarning,
                    )
__all__ = ("verify",)  # the OCSP revocation check entry point is this module's only public name

View File

@@ -0,0 +1,492 @@
from __future__ import annotations
import asyncio
import datetime
import ipaddress
import ssl
import typing
import warnings
from contextlib import asynccontextmanager
from random import randint
from qh3._hazmat import (
Certificate,
OCSPCertStatus,
OCSPRequest,
OCSPResponse,
OCSPResponseStatus,
)
from .....exceptions import RequestException, SSLError
from .....models import PreparedRequest
from .....packages.urllib3 import ConnectionInfo
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
from .....packages.urllib3.exceptions import SecurityWarning
from .....typing import ProxyType
from .....utils import is_cancelled_error_root_cause
from .. import (
_parse_x509_der_cached,
_str_fingerprint_of,
readable_revocation_reason,
)
class InMemoryRevocationStatus:
    """Bounded in-memory cache of OCSP responses for the asyncio flavor.

    Mirrors the synchronous implementation but relies on per-certificate
    asyncio semaphores (cooperative scheduling) instead of thread locks.
    """

    def __init__(self, max_size: int = 2048):
        # Hard cap on stored responses; also bounds the timing sample list.
        self._max_size: int = max_size
        # peer fingerprint (sha256 hex) -> most recent OCSP response.
        self._store: dict[str, OCSPResponse] = {}
        # Per-certificate semaphores handed out by lock() to serialize checks.
        self._semaphores: dict[str, asyncio.Semaphore] = {}
        # peer fingerprint -> issuer certificate matching the stored response.
        self._issuers_map: dict[str, Certificate] = {}
        # Timestamps of successful save() calls, consumed by rate().
        self._timings: list[datetime.datetime] = []
        # Consecutive OCSP failures; reset to zero by a successful save().
        self._failure_count: int = 0
        # When True, verify() short-circuits (soft rate limiter tripped).
        self.hold: bool = False

    @staticmethod
    def support_pickle() -> bool:
        """This gives you a hint on whether you can cache it to restore later."""
        return hasattr(OCSPResponse, "serialize")

    def __getstate__(self) -> dict[str, typing.Any]:
        # Semaphores, timings and the hold flag are runtime-only state and are
        # deliberately rebuilt from scratch in __setstate__.
        return {
            "_max_size": self._max_size,
            "_store": {k: v.serialize() for k, v in self._store.items()},
            "_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
            "_failure_count": self._failure_count,
        }

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
            raise OSError("unrecoverable state for InMemoryRevocationStatus")
        # Runtime-only members are recreated fresh.
        self.hold = False
        self._timings = []
        self._max_size = state["_max_size"]
        self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
        self._store = {}
        self._semaphores = {}
        for k, v in state["_store"].items():
            self._store[k] = OCSPResponse.deserialize(v)
            # One semaphore per restored entry so lock() finds it ready-made.
            self._semaphores[k] = asyncio.Semaphore()
        self._issuers_map = {}
        for k, v in state["_issuers_map"].items():
            self._issuers_map[k] = Certificate.deserialize(v)

    def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
        """Return the cached issuer certificate for *peer_certificate*, if any."""
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        if fingerprint not in self._issuers_map:
            return None
        return self._issuers_map[fingerprint]

    def __len__(self) -> int:
        return len(self._store)

    def incr_failure(self) -> None:
        """Record one more consecutive OCSP responder failure."""
        self._failure_count += 1

    @property
    def failure_count(self) -> int:
        return self._failure_count

    @asynccontextmanager
    async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
        """Serialize revocation checks that target the same peer certificate."""
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        # No await between the membership test and the assignment, so this
        # check-then-create is safe under cooperative scheduling.
        if fingerprint not in self._semaphores:
            self._semaphores[fingerprint] = asyncio.Semaphore()
        await self._semaphores[fingerprint].acquire()
        try:
            yield
        finally:
            self._semaphores[fingerprint].release()

    def rate(self) -> float:
        """Mean delay in seconds between consecutive successful checks (0.0 when unknown)."""
        previous_dt: datetime.datetime | None = None
        delays: list[float] = []
        for dt in self._timings:
            if previous_dt is None:
                previous_dt = dt
                continue
            delays.append((dt - previous_dt).total_seconds())
            previous_dt = dt
        return sum(delays) / len(delays) if delays else 0.0

    def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
        """Return the cached OCSP response for *peer_certificate*, evicting expired GOOD entries."""
        fingerprint: str = _str_fingerprint_of(peer_certificate)
        if fingerprint not in self._store:
            return None
        cached_response = self._store[fingerprint]
        if cached_response.certificate_status == OCSPCertStatus.GOOD:
            # A GOOD verdict is only trusted until its next_update deadline.
            if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
                del self._store[fingerprint]
                return None
            return cached_response
        # Non-GOOD verdicts (revoked/unknown) are kept until evicted by size pressure.
        return cached_response

    def save(
        self,
        peer_certificate: Certificate,
        issuer_certificate: Certificate,
        ocsp_response: OCSPResponse,
    ) -> None:
        """Store *ocsp_response* for *peer_certificate*, evicting one entry when full."""
        if len(self._store) >= self._max_size:
            # Eviction preference: an unsuccessful responder reply first,
            # otherwise the non-revoked entry with the latest next_update,
            # otherwise (everything revoked) the oldest inserted key.
            tbd_key: str | None = None
            closest_next_update: int | None = None
            for k in self._store:
                if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
                    tbd_key = k
                    break
                if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
                    if closest_next_update is None:
                        closest_next_update = self._store[k].next_update
                        tbd_key = k
                        continue
                    if self._store[k].next_update > closest_next_update:  # type: ignore
                        closest_next_update = self._store[k].next_update
                        tbd_key = k
            if tbd_key:
                del self._store[tbd_key]
                del self._issuers_map[tbd_key]
            else:
                first_key = list(self._store.keys())[0]
                del self._store[first_key]
                del self._issuers_map[first_key]
        peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
        self._store[peer_fingerprint] = ocsp_response
        self._issuers_map[peer_fingerprint] = issuer_certificate
        self._failure_count = 0
        self._timings.append(datetime.datetime.now())
        if len(self._timings) >= self._max_size:
            self._timings.pop(0)
async def verify(
    r: PreparedRequest,
    strict: bool = False,
    timeout: float | int = 0.2,
    proxies: ProxyType | None = None,
    resolver: AsyncBaseResolver | None = None,
    happy_eyeballs: bool | int = False,
    cache: InMemoryRevocationStatus | None = None,
) -> None:
    """Asynchronously verify via OCSP the revocation status of the certificate served for *r*.

    Mirrors the synchronous implementation: sets ``r.ocsp_verified`` and raises
    :class:`SSLError` when the certificate is revoked, or (in strict mode only)
    when its status is unknown. Task cancellation and soft failures never break
    the request; they return silently or warn.
    """
    conn_info: ConnectionInfo | None = r.conn_info

    # we can't do anything in that case.
    if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
        return

    endpoints: list[str] = [  # type: ignore
        # exclude non-HTTP endpoint. like ldap.
        ep  # type: ignore
        for ep in list(conn_info.certificate_dict.get("OCSP", []))  # type: ignore
        if ep.startswith("http://")  # type: ignore
    ]

    # well... not all issued certificate have a OCSP entry. e.g. mkcert.
    if not endpoints:
        return

    if cache is None:
        cache = InMemoryRevocationStatus()

    peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)

    # Fast path: answer from the cache without taking the per-certificate semaphore.
    cached_response = cache.check(peer_certificate)

    if cached_response is not None:
        issuer_certificate = cache.get_issuer_of(peer_certificate)

        if issuer_certificate:
            conn_info.issuer_certificate_der = issuer_certificate.public_bytes()

        if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
            if cached_response.certificate_status == OCSPCertStatus.REVOKED:
                r.ocsp_verified = False
                # Fixed: the message parts were previously wrapped in a tuple,
                # so the SSLError rendered as a tuple repr instead of text.
                raise SSLError(
                    f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                    f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
                    "You should avoid trying to request anything from it as the remote has been compromised. "
                    "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                    "for more information."
                )
            elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
                r.ocsp_verified = False
                if strict is True:
                    raise SSLError(
                        f"Unable to establish a secure connection to {r.url} because the issuer does not know "
                        "whether certificate is valid or not. This error occurred because you enabled strict mode "
                        "for the OCSP / Revocation check."
                    )
            else:
                r.ocsp_verified = True

        return

    async with cache.lock(peer_certificate):
        # this feature, by default, is reserved for a reasonable usage.
        if not strict:
            if cache.failure_count >= 4:
                return

            mean_rate_sec = cache.rate()
            cache_count = len(cache)

            if cache_count >= 10 and mean_rate_sec <= 1.0:
                cache.hold = True

            if cache.hold:
                return

        # some corporate environment
        # have invalid OCSP implementation
        # they use a cert that IS NOT in the chain
        # to sign the response. It's weird but true.
        # see https://github.com/jawah/niquests/issues/274
        ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
        verify_signature = strict is True or ignore_signature_without_strict is False

        # why are we doing this twice?
        # because using Semaphore to prevent concurrent
        # revocation check have a heavy toll!
        # todo: respect DRY
        cached_response = cache.check(peer_certificate)

        if cached_response is not None:
            issuer_certificate = cache.get_issuer_of(peer_certificate)

            if issuer_certificate:
                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()

            if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
                if cached_response.certificate_status == OCSPCertStatus.REVOKED:
                    r.ocsp_verified = False
                    # Fixed: tuple-wrapped message flattened (same defect as above).
                    raise SSLError(
                        f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                        f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
                        "You should avoid trying to request anything from it as the remote has been compromised. "
                        "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                        "for more information."
                    )
                elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
                    r.ocsp_verified = False
                    if strict is True:
                        raise SSLError(
                            f"Unable to establish a secure connection to {r.url} because the issuer does not know "
                            "whether certificate is valid or not. This error occurred because you enabled strict mode "
                            "for the OCSP / Revocation check."
                        )
                else:
                    r.ocsp_verified = True

            return

        from .....async_session import AsyncSession

        async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
            session.trust_env = False

            if proxies:
                session.proxies = proxies

            # When using Python native capabilities, you won't have the issuerCA DER by default (Python 3.7 to 3.9).
            # Unfortunately! But no worries, we can circumvent it! (Python 3.10+ is not concerned anymore)
            # Three ways are valid to fetch it (in order of preference, safest to riskiest):
            #   - The issuer can be (but unlikely) a root CA.
            #   - Retrieve it by asking it from the TLS layer.
            #   - Downloading it using specified caIssuers from the peer certificate.
            if conn_info.issuer_certificate_der is None:
                # It could be a root (self-signed) certificate. Or a previously seen issuer.
                issuer_certificate = cache.get_issuer_of(peer_certificate)

                hint_ca_issuers: list[str] = [
                    ep  # type: ignore
                    for ep in list(conn_info.certificate_dict.get("caIssuers", []))  # type: ignore
                    if ep.startswith("http://")  # type: ignore
                ]

                # try to do AIA fetching of intermediate certificate (issuer)
                if issuer_certificate is None and hint_ca_issuers:
                    try:
                        raw_intermediary_response = await session.get(hint_ca_issuers[0])
                    except RequestException as e:
                        # Only a cancellation must abort silently; other
                        # failures fall through to the "no issuer" handling.
                        if is_cancelled_error_root_cause(e):
                            return
                    except asyncio.CancelledError:  # don't raise any error or warnings!
                        return
                    else:
                        if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
                            raw_intermediary_content = raw_intermediary_response.content

                            if raw_intermediary_content is not None:
                                # binary DER
                                if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
                                    issuer_certificate = Certificate(raw_intermediary_content)
                                # b64 PEM
                                elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
                                    issuer_certificate = Certificate(
                                        ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
                                    )

                # Well! We're out of luck. No further should we go.
                if issuer_certificate is None:
                    # aia fetching should be counted as general ocsp failure too.
                    cache.incr_failure()

                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
                                "via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
                                "Revocation check. Reason: Remote did not provide any intermediary certificate."
                            ),
                            SecurityWarning,
                        )
                    return

                conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
            else:
                issuer_certificate = Certificate(conn_info.issuer_certificate_der)

            try:
                req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
            except ValueError:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            "Reason: The X509 OCSP generator failed to assemble the request."
                        ),
                        SecurityWarning,
                    )
                return

            try:
                # Pick one responder at random when the certificate lists several.
                ocsp_http_response = await session.post(
                    endpoints[randint(0, len(endpoints) - 1)],
                    data=req.public_bytes(),
                    headers={"Content-Type": "application/ocsp-request"},
                    timeout=timeout,
                )
            except RequestException as e:
                if is_cancelled_error_root_cause(e):
                    return

                cache.incr_failure()

                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"Reason: {e}"
                        ),
                        SecurityWarning,
                    )
                return
            except asyncio.CancelledError:  # don't raise any error or warnings!
                return

            if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
                if ocsp_http_response.content is None:
                    return

                try:
                    ocsp_resp = OCSPResponse(ocsp_http_response.content)
                except ValueError:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                "Reason: The X509 OCSP parser failed to read the response"
                            ),
                            SecurityWarning,
                        )
                    return

                if verify_signature:
                    # Verify the signature of the OCSP response with issuer public key
                    try:
                        if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()):  # type: ignore[attr-defined]
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} "
                                "because the OCSP response received has been tampered. "
                                "You could be targeted by a MITM attack."
                            )
                    except ValueError:
                        if strict:
                            warnings.warn(
                                (
                                    f"Unable to insure that the remote peer ({r.url}) has a currently "
                                    "valid certificate via OCSP. You are seeing this warning due to "
                                    "enabling strict mode for OCSP / Revocation check. "
                                    "Reason: The X509 OCSP response is signed using an unsupported algorithm."
                                ),
                                SecurityWarning,
                            )

                cache.save(peer_certificate, issuer_certificate, ocsp_resp)

                if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
                    if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
                        r.ocsp_verified = False
                        raise SSLError(
                            f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
                            f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
                            "You should avoid trying to request anything from it as the remote has been compromised. "
                            "See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
                            "for more information."
                        )

                    if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
                        r.ocsp_verified = False
                        if strict is True:
                            raise SSLError(
                                f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
                                "certificate is valid or not. This error occurred because you enabled strict mode for "
                                "the OCSP / Revocation check."
                            )
                    else:
                        r.ocsp_verified = True
                else:
                    if strict:
                        warnings.warn(
                            (
                                f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                                "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                                f"OCSP Server Status: {ocsp_resp.response_status}"
                            ),
                            SecurityWarning,
                        )
            else:
                if strict:
                    warnings.warn(
                        (
                            f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
                            "You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
                            f"OCSP Server Status: {str(ocsp_http_response)}"
                        ),
                        SecurityWarning,
                    )
__all__ = ("verify",)  # the async OCSP revocation check entry point is this module's only public name

View File

@@ -0,0 +1,325 @@
from __future__ import annotations
import typing
from ..._constant import DEFAULT_RETRIES
from ...adapters import BaseAdapter
from ...models import PreparedRequest, Response
from ...packages.urllib3.exceptions import MaxRetryError
from ...packages.urllib3.response import BytesQueueBuffer
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
from ...packages.urllib3.util import Timeout as TimeoutSauce
from ...packages.urllib3.util.retry import Retry
from ...structures import CaseInsensitiveDict
if typing.TYPE_CHECKING:
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType, WSGIApp
from io import BytesIO
class _WSGIRawIO:
    """File-like adapter exposing a WSGI body iterator through read()/stream()."""

    def __init__(self, generator: typing.Generator[bytes, None, None], headers: list[tuple[str, str]]) -> None:
        self._generator = generator
        self._buffer = BytesQueueBuffer()
        self._closed = False
        self.headers = headers
        self.extension: typing.Any = None

    def read(
        self,
        amt: int | None = None,
        decode_content: bool = True,
    ) -> bytes:
        """Return up to *amt* bytes of body; everything when *amt* is None or negative."""
        if self._closed:
            return b""
        if amt is None or amt < 0:
            # Drain the WSGI iterator completely, then hand back the whole buffer.
            for piece in self._generator:
                self._buffer.put(piece)
            return self._buffer.get(len(self._buffer))
        # Pull chunks until the request can be satisfied or the iterator ends.
        while len(self._buffer) < amt:
            try:
                piece = next(self._generator)
            except StopIteration:
                break
            self._buffer.put(piece)
        available = len(self._buffer)
        if available == 0:
            return b""
        return self._buffer.get(amt if amt < available else available)

    def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
        """Iterate over chunks of the response until exhaustion."""
        chunk = self.read(amt)
        while chunk:
            yield chunk
            chunk = self.read(amt)

    def close(self) -> None:
        """Mark the stream closed and dispose of the underlying WSGI iterator."""
        self._closed = True
        closer = getattr(self._generator, "close", None)
        if closer is not None:
            closer()

    def __iter__(self) -> typing.Iterator[bytes]:
        return self

    def __next__(self) -> bytes:
        chunk = self.read(8192)
        if chunk:
            return chunk
        raise StopIteration
class WebServerGatewayInterface(BaseAdapter):
"""Adapter for making requests to WSGI applications directly."""
def __init__(self, app: WSGIApp, max_retries: RetryType = DEFAULT_RETRIES) -> None:
"""
Initialize the WSGI adapter.
:param app: A WSGI application callable.
:param max_retries: Maximum number of retries for requests.
"""
super().__init__()
self.app = app
if isinstance(max_retries, Retry):
self.max_retries = max_retries
else:
self.max_retries = Retry.from_int(max_retries)
    def __repr__(self) -> str:
        """Return the fixed debug tag for this adapter."""
        return "<WSGIAdapter />"
def send(
self,
request: PreparedRequest,
stream: bool = False,
timeout: int | float | tuple | TimeoutSauce | None = None,
verify: TLSVerifyType = True,
cert: TLSClientCertType | None = None,
proxies: ProxyType | None = None,
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
on_early_response: typing.Callable[[Response], None] | None = None,
multiplexed: bool = False,
) -> Response:
"""Send a PreparedRequest to the WSGI application."""
if isinstance(timeout, tuple):
if len(timeout) == 3:
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
else:
timeout = timeout[0] # use connect
elif isinstance(timeout, TimeoutSauce):
timeout = timeout.total or timeout.connect_timeout
retries = self.max_retries
method = request.method or "GET"
while True:
try:
response = self._do_send(request, stream)
except Exception as err:
try:
retries = retries.increment(method, request.url, error=err)
except MaxRetryError:
raise
retries.sleep()
continue
# we rely on the urllib3 implementation for retries
# so we basically mock a response to get it to work
base_response = BaseHTTPResponse(
body=b"",
headers=response.headers,
status=response.status_code,
request_method=request.method,
request_url=request.url,
)
# Check if we should retry based on status code
has_retry_after = bool(response.headers.get("Retry-After"))
if retries.is_retry(method, response.status_code, has_retry_after):
try:
retries = retries.increment(method, request.url, response=base_response)
except MaxRetryError:
if retries.raise_on_status:
raise
return response
retries.sleep(base_response)
continue
return response
def _do_send_sse(self, request: PreparedRequest) -> Response:
"""Handle SSE requests by wrapping the WSGI response with SSE parsing."""
from urllib.parse import urlparse
from ._sse import WSGISSEExtension
# Convert scheme: sse:// -> https://, psse:// -> http://
parsed = urlparse(request.url)
if parsed.scheme == "sse":
http_scheme = "https"
else:
http_scheme = "http"
original_url = request.url
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
environ = self._create_environ(request)
request.url = original_url
status_code = None
response_headers: list[tuple[str, str]] = []
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
nonlocal status_code, response_headers
status_code = int(status.split(" ", 1)[0])
response_headers = headers
result = self.app(environ, start_response)
def generate():
try:
yield from result
finally:
if hasattr(result, "close"):
result.close()
ext = WSGISSEExtension(generate())
response = Response()
response.status_code = status_code
response.headers = CaseInsensitiveDict(response_headers)
response.request = request
response.url = original_url
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
raw_io = _WSGIRawIO(iter([]), response_headers) # type: ignore[arg-type]
raw_io.extension = ext
response.raw = raw_io # type: ignore
response._content = False
response._content_consumed = False
return response
def _do_send(self, request: PreparedRequest, stream: bool) -> Response:
"""Perform the actual WSGI request."""
from urllib.parse import urlparse
parsed = urlparse(request.url)
if parsed.scheme in ("ws", "wss"):
raise NotImplementedError("WebSocket is not supported over WSGI")
if parsed.scheme in ("sse", "psse"):
return self._do_send_sse(request)
environ = self._create_environ(request)
status_code = None
response_headers: list[tuple[str, str]] = []
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
nonlocal status_code, response_headers
status_code = int(status.split(" ", 1)[0])
response_headers = headers
result = self.app(environ, start_response)
response = Response()
response.status_code = status_code
response.headers = CaseInsensitiveDict(response_headers)
response.request = request
response.url = request.url
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
# Wrap the WSGI iterator for streaming
def generate():
try:
yield from result
finally:
if hasattr(result, "close"):
result.close()
response.raw = _WSGIRawIO(generate(), response_headers) # type: ignore
if stream:
response._content = False # Indicate content not yet consumed
response._content_consumed = False
else:
# Consume all content immediately
body_chunks: list[bytes] = []
try:
for chunk in result:
body_chunks.append(chunk)
finally:
if hasattr(result, "close"):
result.close()
response._content = b"".join(body_chunks)
return response
def _create_environ(self, request: PreparedRequest) -> dict:
"""Create a WSGI environ dict from a PreparedRequest."""
from urllib.parse import unquote, urlparse
parsed = urlparse(request.url)
body = request.body or b""
if isinstance(body, str):
body = body.encode("utf-8")
elif isinstance(body, typing.Iterable) and not isinstance(body, (str, bytes, bytearray, tuple)):
tmp = b""
for chunk in body:
tmp += chunk # type: ignore[operator]
body = tmp
environ = {
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": "",
"PATH_INFO": unquote(parsed.path) or "/",
"QUERY_STRING": parsed.query or "",
"SERVER_NAME": parsed.hostname or "localhost",
"SERVER_PORT": str(parsed.port or (443 if parsed.scheme == "https" else 80)),
"SERVER_PROTOCOL": "HTTP/1.1",
"wsgi.version": (1, 0),
"wsgi.url_scheme": parsed.scheme or "http",
"wsgi.input": BytesIO(body), # type: ignore[arg-type]
"wsgi.errors": BytesIO(),
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
"CONTENT_LENGTH": str(len(body)), # type: ignore[arg-type]
}
if request.headers:
for key, value in request.headers.items():
key_upper = key.upper().replace("-", "_")
if key_upper == "CONTENT_TYPE":
environ["CONTENT_TYPE"] = value
elif key_upper == "CONTENT_LENGTH":
environ["CONTENT_LENGTH"] = value
else:
environ[f"HTTP_{key_upper}"] = value
return environ
def close(self) -> None:
"""Clean up adapter resources."""
pass

View File

@@ -0,0 +1,821 @@
from __future__ import annotations
import asyncio
import contextlib
import threading
import typing
from concurrent.futures import Future
from ...._constant import DEFAULT_RETRIES
from ....adapters import AsyncBaseAdapter, BaseAdapter
from ....exceptions import ConnectTimeout, ReadTimeout
from ....models import AsyncResponse, PreparedRequest, Response
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
from ....packages.urllib3.exceptions import MaxRetryError
from ....packages.urllib3.response import BytesQueueBuffer
from ....packages.urllib3.util import Timeout as TimeoutSauce
from ....packages.urllib3.util.retry import Retry
from ....structures import CaseInsensitiveDict
from ....utils import _swap_context
if typing.TYPE_CHECKING:
from ....typing import ASGIApp, ASGIMessage, ProxyType, RetryType, TLSClientCertType, TLSVerifyType
class _ASGIRawIO:
    """Async file-like wrapper around an ASGI response for true async streaming.

    Body chunks arrive as ASGI messages on ``response_queue``; a ``None``
    sentinel on the queue marks end-of-stream. Reads are buffered through a
    ``BytesQueueBuffer`` so callers can request arbitrary byte counts.
    """

    def __init__(
        self,
        response_queue: asyncio.Queue[ASGIMessage | None],
        response_complete: asyncio.Event,
        timeout: float | None = None,
    ) -> None:
        # Queue fed by the app task; None signals the stream is finished.
        self._response_queue = response_queue
        # Set on close() to let the app-side receive() observe a disconnect.
        self._response_complete = response_complete
        # Per-chunk read timeout (seconds); None disables the deadline.
        self._timeout = timeout
        self._buffer = BytesQueueBuffer()
        self._closed = False
        # True once the None sentinel has been consumed from the queue.
        self._finished = False
        # The task running the ASGI app, cancelled on read timeout.
        self._task: asyncio.Task | None = None
        self.headers: dict | None = None
        # Slot used by protocol extensions (SSE/WebSocket) attached to this body.
        self.extension: typing.Any = None

    async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
        """Read up to ``amt`` bytes; ``None``/negative reads the whole body.

        :raises ReadTimeout: if a chunk does not arrive within the timeout.
        """
        if self._closed or self._finished:
            # Drain whatever is left in the local buffer; no more queue reads.
            return self._buffer.get(len(self._buffer))
        if amt is None or amt < 0:
            async for chunk in self._async_iter_chunks():
                self._buffer.put(chunk)
            self._finished = True
            return self._buffer.get(len(self._buffer))
        # Accumulate chunks until the request can be satisfied or EOF hits.
        while len(self._buffer) < amt and not self._finished:
            chunk = await self._get_next_chunk()  # type: ignore[assignment]
            if chunk is None:
                self._finished = True
                break
            self._buffer.put(chunk)
        if len(self._buffer) == 0:
            return b""
        return self._buffer.get(min(amt, len(self._buffer)))

    async def _get_next_chunk(self) -> bytes | None:
        """Await one body chunk from the queue; None means end-of-stream."""
        try:
            async with asyncio_timeout(self._timeout):
                message = await self._response_queue.get()
            if message is None:
                return None
            if message["type"] == "http.response.body":
                return message.get("body", b"")
            # Any other message type (e.g. trailers) terminates the body.
            return None
        except asyncio.TimeoutError:
            # Stop the app task before surfacing the timeout to the caller.
            await self._cancel_task()
            raise ReadTimeout("Read timed out while streaming ASGI response")
        except asyncio.CancelledError:
            return None

    async def _async_iter_chunks(self) -> typing.AsyncGenerator[bytes]:
        """Yield non-empty body chunks until the None sentinel is seen."""
        while True:
            try:
                async with asyncio_timeout(self._timeout):
                    message = await self._response_queue.get()
            except asyncio.TimeoutError:
                await self._cancel_task()
                raise ReadTimeout("Read timed out while streaming ASGI response")
            if message is None:
                break
            if message["type"] == "http.response.body":
                chunk = message.get("body", b"")
                if chunk:
                    yield chunk

    async def _cancel_task(self) -> None:
        """Cancel the app task (if still running) and await its termination."""
        if self._task is not None and not self._task.done():
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task

    def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes]:
        """Return an async generator yielding chunks of at most ``amt`` bytes."""
        return self._async_stream(amt)

    async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes]:
        while True:
            chunk = await self.read(amt)
            if not chunk:
                break
            yield chunk

    def close(self) -> None:
        """Mark the body closed and unblock the app-side receive() waiters."""
        self._closed = True
        self._response_complete.set()

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        return self._async_iter_self()

    async def _async_iter_self(self) -> typing.AsyncIterator[bytes]:
        async for chunk in self._async_iter_chunks():
            yield chunk

    async def __anext__(self) -> bytes:
        chunk = await self.read(8192)
        if not chunk:
            raise StopAsyncIteration
        return chunk
class AsyncServerGatewayInterface(AsyncBaseAdapter):
    """Adapter for making requests to ASGI applications directly.

    The ASGI app runs in-process as an :mod:`asyncio` task; its messages
    are relayed over a queue so responses can be streamed or buffered.
    Retries reuse urllib3's :class:`Retry` machinery against a mocked
    urllib3 response.
    """

    def __init__(
        self,
        app: ASGIApp,
        raise_app_exceptions: bool = True,
        max_retries: RetryType = DEFAULT_RETRIES,
        lifespan_state: dict[str, typing.Any] | None = None,
    ) -> None:
        """
        :param app: The ASGI application callable.
        :param raise_app_exceptions: Re-raise exceptions raised by the app
            (buffered mode) instead of swallowing them.
        :param max_retries: Retry budget, int or urllib3 Retry.
        :param lifespan_state: Shared ``scope["state"]`` dict populated by a
            lifespan handler, copied into each request scope when set.
        """
        super().__init__()
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self._lifespan_state = lifespan_state
        if isinstance(max_retries, Retry):
            self.max_retries = max_retries
        else:
            self.max_retries = Retry.from_int(max_retries)

    def __repr__(self) -> str:
        return "<ASGIAdapter Native/>"

    async def send(
        self,
        request: PreparedRequest,
        stream: bool = False,
        timeout: int | float | tuple | TimeoutSauce | None = None,
        verify: TLSVerifyType = True,
        cert: TLSClientCertType | None = None,
        proxies: ProxyType | None = None,
        on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
        on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
        on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
        multiplexed: bool = False,
    ) -> AsyncResponse:
        """Send a PreparedRequest to the ASGI application.

        TLS/proxy/connection-callback parameters are accepted for interface
        compatibility only; they have no effect on an in-process call.

        :raises MaxRetryError: once the configured retries are exhausted.
        """
        # Collapse the several accepted timeout spellings into one scalar.
        if isinstance(timeout, tuple):
            if len(timeout) == 3:
                timeout = timeout[2] or timeout[0]  # prefer total, fallback connect
            else:
                timeout = timeout[0]  # use connect
        elif isinstance(timeout, TimeoutSauce):
            timeout = timeout.total or timeout.connect_timeout
        retries = self.max_retries
        method = request.method or "GET"
        while True:
            try:
                response = await self._do_send(request, stream, timeout)
            except Exception as err:
                try:
                    retries = retries.increment(method, request.url, error=err)
                except MaxRetryError:
                    raise
                await retries.async_sleep()
                continue
            # we rely on the urllib3 implementation for retries
            # so we basically mock a response to get it to work
            base_response = BaseHTTPResponse(
                body=b"",
                headers=response.headers,
                status=response.status_code,
                request_method=request.method,
                request_url=request.url,
            )
            # Check if we should retry based on status code
            has_retry_after = bool(response.headers.get("Retry-After"))
            if retries.is_retry(method, response.status_code, has_retry_after):
                try:
                    retries = retries.increment(method, request.url, response=base_response)
                except MaxRetryError:
                    if retries.raise_on_status:
                        raise
                    return response
                await retries.async_sleep(base_response)
                continue
            return response

    async def _do_send_ws(self, request: PreparedRequest) -> AsyncResponse:
        """Handle WebSocket requests via the ASGI websocket protocol."""
        from ._ws import ASGIWebSocketExtension

        scope = self._create_ws_scope(request)
        ext = ASGIWebSocketExtension()
        await ext.start(self.app, scope)
        # Fabricate the 101 Switching Protocols response a real server would send.
        response = Response()
        response.status_code = 101
        response.headers = CaseInsensitiveDict({"upgrade": "websocket"})
        response.request = request
        response.url = request.url
        raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event())
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        _swap_context(response)
        return response  # type: ignore

    async def _do_send_sse(
        self,
        request: PreparedRequest,
        timeout: int | float | None,
    ) -> AsyncResponse:
        """Handle SSE requests via ASGI HTTP streaming with SSE parsing."""
        from urllib.parse import urlparse

        from ._sse import ASGISSEExtension

        # Convert scheme: sse:// -> https://, psse:// -> http://
        parsed = urlparse(request.url)
        if parsed.scheme == "sse":
            http_scheme = "https"
        else:
            http_scheme = "http"
        # Build a modified request with http scheme for scope creation
        original_url = request.url
        request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1)  # type: ignore[union-attr,str-bytes-safe]
        scope = self._create_scope(request)
        request.url = original_url
        body = request.body or b""
        if isinstance(body, str):
            body = body.encode("utf-8")
        ext = ASGISSEExtension()
        # start() runs the app and blocks until http.response.start arrives.
        start_message = await ext.start(self.app, scope, body)  # type: ignore[arg-type]
        status_code = start_message["status"]
        response_headers = start_message.get("headers", [])
        headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
        response = Response()
        response.status_code = status_code
        response.headers = CaseInsensitiveDict(headers_dict)
        response.request = request
        response.url = original_url
        response.encoding = response.headers.get("content-type", "utf-8")  # type: ignore[assignment]
        # The body is consumed by the SSE extension; raw only carries metadata.
        raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event(), timeout)
        raw_io.headers = headers_dict
        raw_io.extension = ext
        response.raw = raw_io  # type: ignore
        response._content = False
        response._content_consumed = False
        _swap_context(response)
        return response  # type: ignore

    async def _do_send(
        self,
        request: PreparedRequest,
        stream: bool,
        timeout: int | float | None,
    ) -> AsyncResponse:
        """Perform the actual ASGI request."""
        from urllib.parse import urlparse

        parsed = urlparse(request.url)
        if parsed.scheme in ("ws", "wss"):
            return await self._do_send_ws(request)
        if parsed.scheme in ("sse", "psse"):
            return await self._do_send_sse(request, timeout)
        scope = self._create_scope(request)
        body = request.body or b""
        body_iter: typing.AsyncIterator[bytes] | typing.AsyncIterator[str] | None = None
        # Check if body is an async iterable
        if hasattr(body, "__aiter__"):
            body_iter = body.__aiter__()
            body = b""  # Will be streamed
        elif isinstance(body, str):
            body = body.encode("utf-8")
        request_complete = False
        response_complete = asyncio.Event()
        response_queue: asyncio.Queue[ASGIMessage | None] = asyncio.Queue()
        app_exception: Exception | None = None

        async def receive() -> ASGIMessage:
            # ASGI receive callable handed to the app.
            nonlocal request_complete
            if request_complete:
                # After the body is fully sent, block until the response
                # finishes, then report a disconnect.
                await response_complete.wait()
                return {"type": "http.disconnect"}
            if body_iter is not None:
                # Stream chunks from async iterable
                try:
                    chunk = await body_iter.__anext__()
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopAsyncIteration:
                    request_complete = True
                    return {"type": "http.request", "body": b"", "more_body": False}
            else:
                # Single body chunk
                request_complete = True
                return {"type": "http.request", "body": body, "more_body": False}

        async def send_func(message: ASGIMessage) -> None:
            # ASGI send callable: relay messages to the reader via the queue.
            await response_queue.put(message)
            if message["type"] == "http.response.body" and not message.get("more_body", False):
                response_complete.set()

        async def run_app() -> None:
            nonlocal app_exception
            try:
                await self.app(scope, receive, send_func)
            except Exception as ex:
                app_exception = ex
            finally:
                # None sentinel marks end-of-stream for the queue consumer.
                await response_queue.put(None)

        if stream:
            return await self._stream_response(
                request, response_queue, response_complete, run_app, lambda: app_exception, timeout
            )
        else:
            return await self._buffered_response(
                request, response_queue, response_complete, run_app, lambda: app_exception, timeout
            )

    async def _stream_response(
        self,
        request: PreparedRequest,
        response_queue: asyncio.Queue[ASGIMessage | None],
        response_complete: asyncio.Event,
        run_app: typing.Callable[[], typing.Awaitable[None]],
        get_exception: typing.Callable[[], Exception | None],
        timeout: float | None,
    ) -> AsyncResponse:
        """Start the app task, wait only for the headers, and return a
        response whose body streams lazily from the message queue.

        :raises ConnectTimeout: if headers do not arrive within ``timeout``.
        """
        status_code: int | None = None
        response_headers: list[tuple[bytes, bytes]] = []
        task = asyncio.create_task(run_app())  # type: ignore[var-annotated,arg-type]
        try:
            # Wait for http.response.start with timeout
            async with asyncio_timeout(timeout):
                while True:
                    message = await response_queue.get()
                    if message is None:
                        break
                    if message["type"] == "http.response.start":
                        status_code = message["status"]
                        response_headers = message.get("headers", [])
                        break
            headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
            raw_io = _ASGIRawIO(response_queue, response_complete, timeout)
            raw_io.headers = headers_dict
            # Hand the app task to the raw body so it can cancel on read timeout.
            raw_io._task = task
            response = Response()
            response.status_code = status_code
            response.headers = CaseInsensitiveDict(headers_dict)
            response.request = request
            response.url = request.url
            response.encoding = response.headers.get("content-type", "utf-8")  # type: ignore[assignment]
            response.raw = raw_io  # type: ignore
            response._content = False
            response._content_consumed = False
            _swap_context(response)
            return response  # type: ignore
        except asyncio.TimeoutError:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
            raise ConnectTimeout("Timed out waiting for ASGI response headers")
        except Exception:
            # Never leak the background app task on failure.
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
            raise

    async def _buffered_response(
        self,
        request: PreparedRequest,
        response_queue: asyncio.Queue[ASGIMessage | None],
        response_complete: asyncio.Event,
        run_app: typing.Callable[[], typing.Awaitable[None]],
        get_exception: typing.Callable[[], Exception | None],
        timeout: float | None,
    ) -> AsyncResponse:
        """Run the app to completion and return a fully-buffered response.

        :raises ReadTimeout: if the full body does not arrive within ``timeout``.
        """
        status_code: int | None = None
        response_headers: list[tuple[bytes, bytes]] = []
        body_chunks: list[bytes] = []
        task = asyncio.create_task(run_app())  # type: ignore[var-annotated,arg-type]
        try:
            async with asyncio_timeout(timeout):
                # Drain every message until the None sentinel.
                while True:
                    message = await response_queue.get()
                    if message is None:
                        break
                    if message["type"] == "http.response.start":
                        status_code = message["status"]
                        response_headers = message.get("headers", [])
                    elif message["type"] == "http.response.body":
                        chunk = message.get("body", b"")
                        if chunk:
                            body_chunks.append(chunk)
                await task
        except asyncio.TimeoutError:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
            raise ReadTimeout("Timed out reading ASGI response body")
        except Exception:
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task
            raise
        # Surface app-side exceptions only when configured to do so.
        if self.raise_app_exceptions and get_exception() is not None:
            raise get_exception()  # type: ignore
        headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
        response = Response()
        response.status_code = status_code
        response.headers = CaseInsensitiveDict(headers_dict)
        response.request = request
        response.url = request.url
        response.encoding = response.headers.get("content-type", "utf-8")  # type: ignore[assignment]
        response._content = b"".join(body_chunks)
        response.raw = _ASGIRawIO(response_queue, response_complete, timeout)  # type: ignore
        response.raw.headers = headers_dict
        _swap_context(response)
        return response  # type: ignore[return-value]

    def _create_scope(self, request: PreparedRequest) -> dict:
        """Build an ASGI HTTP connection scope from the request."""
        from urllib.parse import unquote, urlparse

        parsed = urlparse(request.url)
        headers: list[tuple[bytes, bytes]] = []
        if request.headers:
            for key, value in request.headers.items():
                headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
        scope: dict[str, typing.Any] = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "scheme": "http",
            "path": unquote(parsed.path) or "/",
            "query_string": (parsed.query or "").encode("latin-1"),  # type: ignore[union-attr]
            "root_path": "",
            "headers": headers,
            "server": (
                parsed.hostname or "localhost",
                parsed.port or (443 if parsed.scheme == "https" else 80),
            ),
        }
        # Include lifespan state if available (for frameworks like Starlette that use it)
        if self._lifespan_state is not None:
            scope["state"] = self._lifespan_state.copy()
        return scope

    def _create_ws_scope(self, request: PreparedRequest) -> dict:
        """Build an ASGI websocket connection scope from the request."""
        from urllib.parse import unquote, urlparse

        parsed = urlparse(request.url)
        headers: list[tuple[bytes, bytes]] = []
        if request.headers:
            for key, value in request.headers.items():
                headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
        scheme = "wss" if parsed.scheme == "wss" else "ws"
        scope: dict[str, typing.Any] = {
            "type": "websocket",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "scheme": scheme,
            "path": unquote(parsed.path) or "/",
            "query_string": (parsed.query or "").encode("latin-1"),  # type: ignore[union-attr]
            "root_path": "",
            "headers": headers,
            "server": (
                parsed.hostname or "localhost",
                parsed.port or (443 if scheme == "wss" else 80),
            ),
        }
        if self._lifespan_state is not None:
            scope["state"] = self._lifespan_state.copy()
        return scope

    async def close(self) -> None:
        # Nothing to release: the ASGI app is caller-owned.
        pass
class ThreadAsyncServerGatewayInterface(BaseAdapter):
    """Synchronous adapter for ASGI applications using a background event loop.

    A daemon thread hosts a private asyncio loop that runs an
    :class:`AsyncServerGatewayInterface` plus the ASGI lifespan protocol;
    synchronous callers submit coroutines to that loop and block on a
    :class:`concurrent.futures.Future` for the result.
    """

    def __init__(
        self,
        app: ASGIApp,
        raise_app_exceptions: bool = True,
        max_retries: RetryType = DEFAULT_RETRIES,
    ) -> None:
        super().__init__()
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        if isinstance(max_retries, Retry):
            self.max_retries = max_retries
        else:
            self.max_retries = Retry.from_int(max_retries)
        # Lazily-created async adapter living on the background loop.
        self._async_adapter: AsyncServerGatewayInterface | None = None
        self._loop: typing.Any = None  # asyncio.AbstractEventLoop
        self._thread: threading.Thread | None = None
        # Set once the background loop is up and accepting tasks.
        self._started = threading.Event()
        self._lifespan_task: asyncio.Task | None = None
        self._lifespan_receive_queue: asyncio.Queue[ASGIMessage] | None = None
        # Set once lifespan startup finished (successfully or not).
        self._lifespan_startup_complete = threading.Event()
        self._lifespan_startup_failed: Exception | None = None
        # scope["state"] shared between lifespan and request scopes.
        self._lifespan_state: dict[str, typing.Any] = {}
        # Serializes thread startup across concurrent first requests.
        self._startup_lock: threading.Lock = threading.Lock()

    def __repr__(self) -> str:
        return "<ASGIAdapter Thread/>"

    def _ensure_loop_running(self) -> None:
        """Start the background event loop thread if not already running."""
        with self._startup_lock:
            if self._thread is not None and self._thread.is_alive():
                return
            import asyncio

            def run_loop() -> None:
                # Runs on the daemon thread: owns the loop for its lifetime.
                self._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self._loop)
                self._lifespan_receive_queue = asyncio.Queue()
                self._async_adapter = AsyncServerGatewayInterface(
                    self.app,
                    raise_app_exceptions=self.raise_app_exceptions,
                    max_retries=self.max_retries,
                    lifespan_state=self._lifespan_state,
                )
                # Start lifespan handler
                self._lifespan_task = self._loop.create_task(self._handle_lifespan())
                self._started.set()
                self._loop.run_forever()

            self._thread = threading.Thread(target=run_loop, daemon=True)
            self._thread.start()
            self._started.wait()
            # Block until lifespan startup resolved; re-raise its failure here
            # so the caller sees it synchronously.
            self._lifespan_startup_complete.wait()
            if self._lifespan_startup_failed is not None:
                raise self._lifespan_startup_failed

    async def _handle_lifespan(self) -> None:
        """Handle ASGI lifespan protocol."""
        scope = {
            "type": "lifespan",
            "asgi": {"version": "3.0"},
            "state": self._lifespan_state,
        }
        startup_complete = asyncio.Event()
        shutdown_complete = asyncio.Event()
        startup_failed: list[Exception] = []
        # Keep local reference to avoid race condition during shutdown
        receive_queue = self._lifespan_receive_queue

        async def receive() -> ASGIMessage:
            return await receive_queue.get()  # type: ignore[union-attr]

        async def send(message: ASGIMessage) -> None:
            # Translate lifespan protocol messages into local events.
            if message["type"] == "lifespan.startup.complete":
                startup_complete.set()
            elif message["type"] == "lifespan.startup.failed":
                startup_failed.append(RuntimeError(message.get("message", "Lifespan startup failed")))
                startup_complete.set()
            elif message["type"] == "lifespan.shutdown.complete":
                shutdown_complete.set()
            elif message["type"] == "lifespan.shutdown.failed":
                shutdown_complete.set()

        async def run_lifespan() -> None:
            try:
                await self.app(scope, receive, send)
            except Exception as e:
                # An app without lifespan support raises immediately; treat
                # that as a startup failure only if startup never completed.
                if not startup_complete.is_set():
                    startup_failed.append(e)
                    startup_complete.set()

        lifespan_task = asyncio.create_task(run_lifespan())
        # Send startup event
        await receive_queue.put({"type": "lifespan.startup"})  # type: ignore[union-attr]
        await startup_complete.wait()
        if startup_failed:
            self._lifespan_startup_failed = startup_failed[0]
        self._lifespan_startup_complete.set()
        # Wait for shutdown signal (loop.stop() will cancel this)
        try:
            await asyncio.Future()  # Wait forever until canceled
        except (asyncio.CancelledError, GeneratorExit):
            pass
        # Send shutdown event - must happen before loop stops
        if receive_queue is not None:
            try:
                await receive_queue.put({"type": "lifespan.shutdown"})
                await asyncio.wait_for(shutdown_complete.wait(), timeout=5.0)
            except (asyncio.TimeoutError, asyncio.CancelledError, RuntimeError):
                pass
        lifespan_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await lifespan_task

    def _do_send_ws(self, request: PreparedRequest) -> Response:
        """Handle WebSocket requests synchronously via the background loop."""
        from .._ws import ThreadASGIWebSocketExtension

        self._ensure_loop_running()
        future: Future[Response] = Future()

        async def run_ws() -> None:
            try:
                result = await self._async_adapter._do_send_ws(request)  # type: ignore[union-attr]
                _swap_context(result)
                # Wrap the async WS extension in a sync wrapper
                async_ext = result.raw.extension  # type: ignore[union-attr]
                result.raw.extension = ThreadASGIWebSocketExtension(async_ext, self._loop)  # type: ignore[union-attr]
                future.set_result(result)  # type: ignore[arg-type]
            except Exception as e:
                future.set_exception(e)

        self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_ws()))
        return future.result()

    def _do_send_sse(self, request: PreparedRequest, timeout: int | float | None = None) -> Response:
        """Handle SSE requests synchronously via the background loop."""
        from .._sse import ThreadASGISSEExtension

        self._ensure_loop_running()
        future: Future[Response] = Future()

        async def run_sse() -> None:
            try:
                result = await self._async_adapter._do_send_sse(request, timeout)  # type: ignore[union-attr]
                _swap_context(result)
                # Wrap the async SSE extension in a sync wrapper
                async_ext = result.raw.extension  # type: ignore[union-attr]
                result.raw.extension = ThreadASGISSEExtension(async_ext, self._loop)  # type: ignore[union-attr]
                future.set_result(result)  # type: ignore[arg-type]
            except Exception as e:
                future.set_exception(e)

        self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_sse()))
        return future.result()

    def send(
        self,
        request: PreparedRequest,
        stream: bool = False,
        timeout: int | float | tuple | TimeoutSauce | None = None,
        verify: TLSVerifyType = True,
        cert: TLSClientCertType | None = None,
        proxies: ProxyType | None = None,
        on_post_connection: typing.Callable[[typing.Any], None] | None = None,
        on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
        on_early_response: typing.Callable[[Response], None] | None = None,
        multiplexed: bool = False,
    ) -> Response:
        """Send a PreparedRequest to the ASGI application synchronously.

        Delegates to the async adapter on the background loop and blocks
        for the result. Streaming bodies are not supported in this mode.

        :raises ValueError: if ``stream=True`` is requested.
        """
        # Collapse the several accepted timeout spellings into one scalar.
        if isinstance(timeout, tuple):
            if len(timeout) == 3:
                timeout = timeout[2] or timeout[0]  # prefer total, fallback connect
            else:
                timeout = timeout[0]  # use connect
        elif isinstance(timeout, TimeoutSauce):
            timeout = timeout.total or timeout.connect_timeout
        from urllib.parse import urlparse

        parsed = urlparse(request.url)
        if parsed.scheme in ("ws", "wss"):
            return self._do_send_ws(request)
        if parsed.scheme in ("sse", "psse"):
            return self._do_send_sse(request, timeout)
        if stream:
            raise ValueError(
                "ThreadAsyncServerGatewayInterface does not support streaming responses. "
                "Use stream=False or migrate to pure async/await implementation."
            )
        self._ensure_loop_running()
        future: Future[Response] = Future()

        async def run_send() -> None:
            try:
                result = await self._async_adapter.send(  # type: ignore[union-attr]
                    request,
                    stream=False,
                    timeout=timeout,
                    verify=verify,
                    cert=cert,
                    proxies=proxies,
                )
                # Convert the AsyncResponse into a sync Response for the caller.
                _swap_context(result)
                future.set_result(result)  # type: ignore[arg-type]
            except Exception as e:
                future.set_exception(e)

        self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_send()))
        return future.result()

    def close(self) -> None:
        """Clean up adapter resources."""
        if self._loop is not None and self._lifespan_task is not None:
            # Signal shutdown and wait for it to complete
            shutdown_done = threading.Event()

            async def do_shutdown() -> None:
                # Cancelling the lifespan task triggers its shutdown phase.
                if self._lifespan_task is not None:
                    self._lifespan_task.cancel()
                    with contextlib.suppress(asyncio.CancelledError):
                        await self._lifespan_task
                shutdown_done.set()

            self._loop.call_soon_threadsafe(lambda: self._loop.create_task(do_shutdown()))
            # 6s > the 5s lifespan shutdown wait, so a hung app can't block us.
            shutdown_done.wait(timeout=6.0)
        if self._loop is not None:
            self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread is not None:
            self._thread.join(timeout=5.0)
            self._thread = None
        # Clear resources only after thread has stopped
        self._loop = None
        self._async_adapter = None
        self._lifespan_task = None
        self._lifespan_receive_queue = None
        self._started.clear()
        self._lifespan_startup_complete.clear()
        self._lifespan_startup_failed = None
__all__ = ("AsyncServerGatewayInterface", "ThreadAsyncServerGatewayInterface")

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import asyncio
import contextlib
import typing
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
class ASGISSEExtension:
"""Async SSE extension for ASGI applications.
Runs the ASGI app as a normal HTTP streaming request and parses
SSE events from the response body chunks.
"""
    def __init__(self) -> None:
        # True once the stream has ended.
        self._closed = False
        # Decoded text accumulated from body chunks, awaiting event parsing.
        self._buffer: str = ""
        # Last seen event id, re-applied to id-less events per the SSE spec.
        self._last_event_id: str | None = None
        # Raw ASGI messages from the app task; a None entry marks end-of-stream.
        self._response_queue: asyncio.Queue[dict[str, typing.Any] | None] | None = None
        # Background task running the ASGI application.
        self._task: asyncio.Task[None] | None = None
    async def start(
        self,
        app: typing.Any,
        scope: dict[str, typing.Any],
        body: bytes = b"",
    ) -> dict[str, typing.Any]:
        """Start the ASGI app and wait for http.response.start.

        Launches the app as a background task wired to a message queue,
        sends ``body`` as a single request chunk, and blocks until the app
        emits its response-start message.

        :param app: The ASGI application callable.
        :param scope: The ASGI HTTP connection scope to run under.
        :param body: The full request body, sent in one chunk.
        :returns: The ``http.response.start`` message (status and headers).
        :raises ConnectionError: if the app finishes without sending headers.
        """
        self._response_queue = asyncio.Queue()
        request_complete = False
        response_complete = asyncio.Event()

        async def receive() -> dict[str, typing.Any]:
            nonlocal request_complete
            if request_complete:
                # Body already delivered: wait for the response to finish,
                # then report a client disconnect.
                await response_complete.wait()
                return {"type": "http.disconnect"}
            request_complete = True
            return {"type": "http.request", "body": body, "more_body": False}

        async def send(message: dict[str, typing.Any]) -> None:
            await self._response_queue.put(message)  # type: ignore[union-attr]
            if message["type"] == "http.response.body" and not message.get("more_body", False):
                response_complete.set()

        async def run_app() -> None:
            try:
                await app(scope, receive, send)
            finally:
                # None sentinel tells readers the stream has ended.
                await self._response_queue.put(None)  # type: ignore[union-attr]

        self._task = asyncio.create_task(run_app())
        # Wait for http.response.start
        while True:
            message = await self._response_queue.get()
            if message is None:
                raise ConnectionError("ASGI app closed before sending response headers")
            if message["type"] == "http.response.start":
                return message
    @property
    def closed(self) -> bool:
        """Whether the SSE stream has ended and no further events will arrive."""
        return self._closed
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
"""Read and parse the next SSE event from the ASGI response stream.
Returns None when the stream ends."""
if self._closed:
raise OSError("The SSE extension is closed")
while True:
# Check if we already have a complete event in the buffer
sep_idx = self._buffer.find("\n\n")
if sep_idx == -1:
sep_idx = self._buffer.find("\r\n\r\n")
if sep_idx != -1:
sep_len = 4
else:
sep_len = 2
else:
sep_len = 2
if sep_idx != -1:
raw_event = self._buffer[:sep_idx]
self._buffer = self._buffer[sep_idx + sep_len :]
event = self._parse_event(raw_event)
if event is not None:
if raw:
return raw_event + "\n\n"
return event
# Empty event (e.g. just comments), try next
continue
# Need more data from the ASGI response queue
chunk = await self._read_chunk()
if chunk is None:
self._closed = True
return None
self._buffer += chunk
async def _read_chunk(self) -> str | None:
"""Read the next body chunk from the ASGI response."""
if self._response_queue is None:
return None
message = await self._response_queue.get()
if message is None:
return None
if message["type"] == "http.response.body":
body = message.get("body", b"")
if body:
return body.decode("utf-8")
return None
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
"""Parse a raw SSE event block into a ServerSentEvent."""
kwargs: dict[str, typing.Any] = {}
for line in raw_event.splitlines():
if not line or line.startswith(":"):
continue
key, _, value = line.partition(":")
if key not in {"event", "data", "retry", "id"}:
continue
if value.startswith(" "):
value = value[1:]
if key == "id":
if "\u0000" in value:
continue
if key == "retry":
try:
value = int(value) # type: ignore[assignment]
except (ValueError, TypeError):
continue
kwargs[key] = value
if not kwargs:
return None
if "id" not in kwargs and self._last_event_id is not None:
kwargs["id"] = self._last_event_id
event = ServerSentEvent(**kwargs)
if event.id:
self._last_event_id = event.id
return event
async def send_payload(self, buf: str | bytes) -> None:
"""SSE is one-way only."""
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
async def close(self) -> None:
"""Close the SSE stream and clean up the app task."""
if self._closed:
return
self._closed = True
if self._task is not None and not self._task.done():
self._task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._task

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
import asyncio
import contextlib
import typing
class ASGIWebSocketExtension:
    """Async WebSocket extension for ASGI applications.

    Bridges the caller and the ASGI application through a pair of
    queues that follow the ASGI ``websocket.*`` message protocol.
    """

    def __init__(self) -> None:
        self._closed = False
        # Messages emitted by the app (through its "send" callable).
        self._app_send_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
        # Messages delivered to the app (through its "receive" callable).
        self._app_receive_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
        self._task: asyncio.Task[None] | None = None

    async def start(self, app: typing.Any, scope: dict[str, typing.Any]) -> None:
        """Spawn the ASGI app task and perform the WebSocket handshake."""

        async def app_receive() -> dict[str, typing.Any]:
            return await self._app_receive_queue.get()

        async def app_send(message: dict[str, typing.Any]) -> None:
            await self._app_send_queue.put(message)

        self._task = asyncio.create_task(app(scope, app_receive, app_send))

        # Offer the connect event, then wait for the app's verdict.
        await self._app_receive_queue.put({"type": "websocket.connect"})
        handshake = await self._app_send_queue.get()
        kind = handshake["type"]

        if kind == "websocket.accept":
            return

        # Handshake failed: tear the app task down before raising.
        self._closed = True
        self._task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await self._task
        if kind == "websocket.close":
            raise ConnectionError(f"WebSocket connection rejected with code {handshake.get('code', 1000)}")
        raise ConnectionError(f"Unexpected ASGI message during handshake: {kind}")

    @property
    def closed(self) -> bool:
        return self._closed

    async def next_payload(self) -> str | bytes | None:
        """Await the next message from the ASGI WebSocket app.

        Returns None when the app closes the connection.
        """
        if self._closed:
            raise OSError("The WebSocket extension is closed")
        message = await self._app_send_queue.get()
        kind = message["type"]
        if kind == "websocket.close":
            self._closed = True
            return None
        if kind != "websocket.send":
            return None
        # "text" takes precedence over "bytes"; an empty send yields b"".
        for payload_key in ("text", "bytes"):
            if payload_key in message:
                return message[payload_key]
        return b""

    async def send_payload(self, buf: str | bytes) -> None:
        """Forward a message to the ASGI WebSocket app."""
        if self._closed:
            raise OSError("The WebSocket extension is closed")
        payload: dict[str, typing.Any] = {"type": "websocket.receive"}
        if isinstance(buf, (bytes, bytearray)):
            payload["bytes"] = bytes(buf)
        else:
            payload["text"] = buf
        await self._app_receive_queue.put(payload)

    async def close(self) -> None:
        """Notify the app of disconnection and cancel its task."""
        if self._closed:
            return
        self._closed = True
        await self._app_receive_queue.put({"type": "websocket.disconnect", "code": 1000})
        task = self._task
        if task is not None and not task.done():
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
import typing
from concurrent.futures import Future
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
if typing.TYPE_CHECKING:
import asyncio
from ._async._sse import ASGISSEExtension
class WSGISSEExtension:
    """SSE extension for WSGI applications.

    Reads from a WSGI response iterator, buffers text, and parses SSE
    events out of it on demand.
    """

    def __init__(self, generator: typing.Generator[bytes, None, None]) -> None:
        # WSGI response iterable yielding body chunks.
        self._generator = generator
        self._closed = False
        # Decoded-but-unparsed text carried over between reads.
        self._buffer: str = ""
        # Last seen event id; re-applied to id-less events per the SSE spec.
        self._last_event_id: str | None = None

    @property
    def closed(self) -> bool:
        return self._closed

    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Read and parse the next SSE event from the WSGI response.

        Returns None when the stream ends. With ``raw=True`` the unparsed
        event block (LF-terminated) is returned instead.
        """
        if self._closed:
            raise OSError("The SSE extension is closed")
        while True:
            # Locate the EARLIEST event separator, whether LF- or CRLF-
            # delimited; taking whichever match comes first prevents a
            # CRLF-delimited event from being merged into a later LF one.
            lf_idx = self._buffer.find("\n\n")
            crlf_idx = self._buffer.find("\r\n\r\n")
            if lf_idx != -1 and (crlf_idx == -1 or lf_idx < crlf_idx):
                sep_idx, sep_len = lf_idx, 2
            elif crlf_idx != -1:
                sep_idx, sep_len = crlf_idx, 4
            else:
                sep_idx = -1
                sep_len = 0
            if sep_idx != -1:
                raw_event = self._buffer[:sep_idx]
                self._buffer = self._buffer[sep_idx + sep_len :]
                event = self._parse_event(raw_event)
                if event is not None:
                    if raw:
                        return raw_event + "\n\n"
                    return event
                # Empty event (e.g. just comments), try next
                continue
            # Need more data
            chunk = self._read_chunk()
            if chunk is None:
                self._closed = True
                return None
            self._buffer += chunk

    def _read_chunk(self) -> str | None:
        """Return the next non-empty chunk, or None at end of stream.

        PEP 3333 allows the WSGI iterable to yield empty bytestrings;
        those are skipped instead of being mistaken for end-of-stream.
        """
        for chunk in self._generator:
            if chunk:
                return chunk.decode("utf-8")
        return None

    def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
        """Parse a raw SSE event block into a ServerSentEvent.

        Follows the WHATWG field rules: comment lines and unknown fields
        are skipped, consecutive ``data`` lines are joined with a newline,
        an ``id`` containing NUL is ignored and a non-integer ``retry``
        is dropped.
        """
        kwargs: dict[str, typing.Any] = {}
        data_lines: list[str] = []
        for line in raw_event.splitlines():
            if not line or line.startswith(":"):
                continue
            key, _, value = line.partition(":")
            if key not in {"event", "data", "retry", "id"}:
                continue
            if value.startswith(" "):
                value = value[1:]
            if key == "data":
                # Per the SSE spec, multiple data lines accumulate and are
                # joined with "\n" instead of the last one winning.
                data_lines.append(value)
                continue
            if key == "id" and "\u0000" in value:
                continue
            if key == "retry":
                try:
                    value = int(value)  # type: ignore[assignment]
                except (ValueError, TypeError):
                    continue
            kwargs[key] = value
        if data_lines:
            kwargs["data"] = "\n".join(data_lines)
        if not kwargs:
            return None
        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id
        event = ServerSentEvent(**kwargs)
        if event.id:
            self._last_event_id = event.id
        return event

    def send_payload(self, buf: str | bytes) -> None:
        """SSE is one-way only."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    def close(self) -> None:
        """Close the stream and release resources."""
        if self._closed:
            return
        self._closed = True
        # WSGI bodies may expose close() for cleanup (PEP 3333); call it.
        if hasattr(self._generator, "close"):
            self._generator.close()
class ThreadASGISSEExtension:
    """Synchronous SSE extension wrapping an async ASGISSEExtension.

    Every call is forwarded to the async extension on a background event
    loop; the calling thread blocks on a concurrent Future until the
    coroutine finishes or raises.
    """

    def __init__(self, async_ext: ASGISSEExtension, loop: asyncio.AbstractEventLoop) -> None:
        self._async_ext = async_ext
        self._loop = loop

    @property
    def closed(self) -> bool:
        return self._async_ext.closed

    def _bridge(self, coro_factory: typing.Callable[[], typing.Any]) -> typing.Any:
        """Run *coro_factory*'s coroutine on the loop; block for its outcome."""
        waiter: Future[typing.Any] = Future()

        async def _relay() -> None:
            try:
                waiter.set_result(await coro_factory())
            except Exception as exc:
                waiter.set_exception(exc)

        self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_relay()))
        return waiter.result()

    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Block until the next SSE event arrives from the ASGI app."""
        return self._bridge(lambda: self._async_ext.next_payload(raw=raw))

    def send_payload(self, buf: str | bytes) -> None:
        """SSE is one-way only."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    def close(self) -> None:
        """Close the SSE stream and clean up."""
        self._bridge(self._async_ext.close)

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import typing
from concurrent.futures import Future
if typing.TYPE_CHECKING:
import asyncio
from ._async._ws import ASGIWebSocketExtension
class ThreadASGIWebSocketExtension:
    """Synchronous WebSocket extension wrapping an async ASGIWebSocketExtension.

    Every call is forwarded to the async extension on a background event
    loop; the calling thread blocks on a concurrent Future until the
    coroutine finishes or raises.
    """

    def __init__(self, async_ext: ASGIWebSocketExtension, loop: asyncio.AbstractEventLoop) -> None:
        self._async_ext = async_ext
        self._loop = loop

    @property
    def closed(self) -> bool:
        return self._async_ext.closed

    def _bridge(self, coro_factory: typing.Callable[[], typing.Any]) -> typing.Any:
        """Run *coro_factory*'s coroutine on the loop; block for its outcome."""
        waiter: Future[typing.Any] = Future()

        async def _relay() -> None:
            try:
                waiter.set_result(await coro_factory())
            except Exception as exc:
                waiter.set_exception(exc)

        self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_relay()))
        return waiter.result()

    def next_payload(self) -> str | bytes | None:
        """Block until the next message arrives from the ASGI WebSocket app."""
        return self._bridge(self._async_ext.next_payload)

    def send_payload(self, buf: str | bytes) -> None:
        """Send a message to the ASGI WebSocket app."""
        self._bridge(lambda: self._async_ext.send_payload(buf))

    def close(self) -> None:
        """Close the WebSocket and clean up."""
        self._bridge(self._async_ext.close)

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import socket
import typing
from urllib.parse import unquote
from ..._constant import DEFAULT_POOLBLOCK
from ...adapters import HTTPAdapter
from ...exceptions import RequestException
from ...packages.urllib3.connection import HTTPConnection
from ...packages.urllib3.connectionpool import HTTPConnectionPool
from ...packages.urllib3.contrib.webextensions import ServerSideEventExtensionFromHTTP, WebSocketExtensionFromHTTP
from ...packages.urllib3.poolmanager import PoolManager
from ...typing import CacheLayerAltSvcType
from ...utils import select_proxy
class UnixServerSideEventExtensionFromHTTP(ServerSideEventExtensionFromHTTP):
    """SSE web-extension variant that rides over a UNIX domain socket."""

    @staticmethod
    def implementation() -> str:
        # Identifies this extension as the "unix" transport implementation.
        return "unix"

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Only the "psse" scheme is mapped; any other scheme raises KeyError.
        return {"psse": "http+unix"}[scheme]
# WebSocket support is optional upstream; only define the unix variant when
# the base extension class is actually available.
if WebSocketExtensionFromHTTP is not None:

    class UnixWebSocketExtensionFromHTTP(WebSocketExtensionFromHTTP):
        """WebSocket web-extension variant that rides over a UNIX domain socket."""

        @staticmethod
        def implementation() -> str:
            # Identifies this extension as the "unix" transport implementation.
            return "unix"

        @staticmethod
        def scheme_to_http_scheme(scheme: str) -> str:
            # Only the "ws" scheme is mapped; any other scheme raises KeyError.
            return {"ws": "http+unix"}[scheme]
class UnixHTTPConnection(HTTPConnection):
    """HTTP connection that tunnels through a UNIX domain socket.

    The percent-encoded socket path is carried in the URL host; it is
    decoded here and the visible ``host`` is reduced to the socket file
    name.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The host arrives percent-encoded (e.g. %2Ftmp%2Fapp.sock); decode it
        # to recover the filesystem path.
        self.host: str = unquote(self.host)
        self.socket_path = self.host
        # Keep only the basename so the Host header / logging stays sane.
        self.host = self.socket_path.split("/")[-1]

    def connect(self):
        # Open a stream socket against the filesystem path instead of TCP.
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # NOTE(review): assumes self.timeout is settimeout-compatible
        # (float or None) by the time connect() runs — confirm upstream.
        sock.settimeout(self.timeout)
        sock.connect(self.socket_path)
        self.sock = sock
        self._post_conn()
class UnixHTTPConnectionPool(HTTPConnectionPool):
    """Connection pool whose connections speak HTTP over a UNIX socket."""

    ConnectionCls = UnixHTTPConnection
class UnixAdapter(HTTPAdapter):
    """Niquests transport adapter for ``http+unix://`` URLs."""

    def init_poolmanager(
        self,
        connections: int,
        maxsize: int,
        block: bool = DEFAULT_POOLBLOCK,
        quic_cache_layer: CacheLayerAltSvcType | None = None,
        **pool_kwargs: typing.Any,
    ):
        """Build a PoolManager that routes the http+unix scheme to unix pools."""
        self._pool_connections = connections
        self._pool_maxsize = maxsize
        self._pool_block = block
        self._quic_cache_layer = quic_cache_layer
        self.poolmanager = PoolManager(
            num_pools=connections,
            maxsize=maxsize,
            block=block,
            preemptive_quic_cache=quic_cache_layer,
            **pool_kwargs,
        )
        # Reuse the plain-http pool key function so pooling behaves the same.
        self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
        # Restrict this manager to the unix scheme only.
        self.poolmanager.pool_classes_by_scheme = {
            "http+unix": UnixHTTPConnectionPool,
        }

    def get_connection(self, url, proxies=None):
        """Return a pool for *url*; proxies are rejected for unix sockets."""
        proxy = select_proxy(url, proxies)
        if proxy:
            raise RequestException("unix socket cannot be associated with proxies")
        return self.poolmanager.connection_from_url(url)

    def request_url(self, request, proxies):
        # The socket path is encoded in the pool key; only send the path part.
        return request.path_url

    def close(self):
        """Release all pooled connections."""
        self.poolmanager.clear()
__all__ = ("UnixAdapter",)

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
import socket
import typing
from urllib.parse import unquote
from ...._constant import DEFAULT_POOLBLOCK
from ....adapters import AsyncHTTPAdapter
from ....exceptions import RequestException
from ....packages.urllib3._async.connection import AsyncHTTPConnection
from ....packages.urllib3._async.connectionpool import AsyncHTTPConnectionPool
from ....packages.urllib3._async.poolmanager import AsyncPoolManager
from ....packages.urllib3.contrib.ssa import AsyncSocket
from ....packages.urllib3.contrib.webextensions._async import (
AsyncServerSideEventExtensionFromHTTP,
AsyncWebSocketExtensionFromHTTP,
)
from ....typing import CacheLayerAltSvcType
from ....utils import select_proxy
class AsyncUnixServerSideEventExtensionFromHTTP(AsyncServerSideEventExtensionFromHTTP):
    """Async SSE web-extension variant that rides over a UNIX domain socket."""

    @staticmethod
    def implementation() -> str:
        # Identifies this extension as the "unix" transport implementation.
        return "unix"

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Only the "psse" scheme is mapped; any other scheme raises KeyError.
        return {"psse": "http+unix"}[scheme]
# WebSocket support is optional upstream; only define the unix variant when
# the base extension class is actually available.
if AsyncWebSocketExtensionFromHTTP is not None:

    class AsyncUnixWebSocketExtensionFromHTTP(AsyncWebSocketExtensionFromHTTP):
        """Async WebSocket web-extension variant that rides over a UNIX domain socket."""

        @staticmethod
        def implementation() -> str:
            # Identifies this extension as the "unix" transport implementation.
            return "unix"

        @staticmethod
        def scheme_to_http_scheme(scheme: str) -> str:
            # Only the "ws" scheme is mapped; any other scheme raises KeyError.
            return {"ws": "http+unix"}[scheme]
class AsyncUnixHTTPConnection(AsyncHTTPConnection):
    """Async HTTP connection that tunnels through a UNIX domain socket.

    The percent-encoded socket path is carried in the URL host; it is
    decoded here and the visible ``host`` is reduced to the socket file
    name.
    """

    def __init__(self, host, **kwargs):
        super().__init__(host, **kwargs)
        # The host arrives percent-encoded (e.g. %2Ftmp%2Fapp.sock); decode it
        # to recover the filesystem path.
        self.host: str = unquote(self.host)
        self.socket_path = self.host
        # Keep only the basename so the Host header / logging stays sane.
        self.host = self.socket_path.split("/")[-1]

    async def connect(self):
        # Open an async stream socket against the filesystem path, not TCP.
        sock = AsyncSocket(socket.AF_UNIX, socket.SOCK_STREAM)
        # NOTE(review): assumes self.timeout is settimeout-compatible
        # (float or None) by the time connect() runs — confirm upstream.
        sock.settimeout(self.timeout)
        await sock.connect(self.socket_path)
        self.sock = sock
        await self._post_conn()
class AsyncUnixHTTPConnectionPool(AsyncHTTPConnectionPool):
    """Async connection pool whose connections speak HTTP over a UNIX socket."""

    ConnectionCls = AsyncUnixHTTPConnection
class AsyncUnixAdapter(AsyncHTTPAdapter):
    """Async Niquests transport adapter for ``http+unix://`` URLs."""

    def init_poolmanager(
        self,
        connections: int,
        maxsize: int,
        block: bool = DEFAULT_POOLBLOCK,
        quic_cache_layer: CacheLayerAltSvcType | None = None,
        **pool_kwargs: typing.Any,
    ):
        """Build an AsyncPoolManager that routes http+unix to unix pools."""
        self._pool_connections = connections
        self._pool_maxsize = maxsize
        self._pool_block = block
        self._quic_cache_layer = quic_cache_layer
        self.poolmanager = AsyncPoolManager(
            num_pools=connections,
            maxsize=maxsize,
            block=block,
            preemptive_quic_cache=quic_cache_layer,
            **pool_kwargs,
        )
        # Reuse the plain-http pool key function so pooling behaves the same.
        self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
        # Restrict this manager to the unix scheme only.
        self.poolmanager.pool_classes_by_scheme = {
            "http+unix": AsyncUnixHTTPConnectionPool,
        }

    def get_connection(self, url, proxies=None):
        """Return a pool for *url*; proxies are rejected for unix sockets."""
        proxy = select_proxy(url, proxies)
        if proxy:
            raise RequestException("unix socket cannot be associated with proxies")
        return self.poolmanager.connection_from_url(url)

    def request_url(self, request, proxies):
        # The socket path is encoded in the pool key; only send the path part.
        return request.path_url

    # NOTE(review): unlike the sync UnixAdapter there is no close() override
    # here — presumably AsyncHTTPAdapter's own cleanup releases pooled
    # connections; confirm.
__all__ = ("AsyncUnixAdapter",)