fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097) 포트 충돌로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,159 @@
"""
Python HTTP library with thread-safe connection pooling, file post support, user-friendly, and more
"""
from __future__ import annotations
# Set default logging handler to avoid "No handler found" warnings.
import logging
import typing
import warnings
from logging import NullHandler
from os import environ
from . import exceptions
from ._async.connectionpool import AsyncHTTPConnectionPool, AsyncHTTPSConnectionPool
from ._async.connectionpool import connection_from_url as async_connection_from_url
from ._async.poolmanager import AsyncPoolManager, AsyncProxyManager
from ._async.poolmanager import proxy_from_url as async_proxy_from_url
from ._async.response import AsyncHTTPResponse
from ._collections import HTTPHeaderDict
from ._typing import _TYPE_BODY, _TYPE_FIELDS
from ._version import __version__
from .backend import ConnectionInfo, HttpVersion, ResponsePromise
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from .contrib.resolver import ResolverDescription
from .contrib.resolver._async import AsyncResolverDescription
from .filepost import encode_multipart_formdata
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
from .response import BaseHTTPResponse, HTTPResponse
from .util.request import make_headers
from .util.retry import Retry
from .util.timeout import Timeout
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = __version__
__all__ = (
"HTTPConnectionPool",
"HTTPHeaderDict",
"HTTPSConnectionPool",
"PoolManager",
"ProxyManager",
"HTTPResponse",
"Retry",
"Timeout",
"add_stderr_logger",
"connection_from_url",
"disable_warnings",
"encode_multipart_formdata",
"make_headers",
"proxy_from_url",
"request",
"BaseHTTPResponse",
"HttpVersion",
"ConnectionInfo",
"ResponsePromise",
"ResolverDescription",
"AsyncHTTPResponse",
"AsyncResolverDescription",
"AsyncHTTPConnectionPool",
"AsyncHTTPSConnectionPool",
"AsyncPoolManager",
"AsyncProxyManager",
"async_proxy_from_url",
"async_connection_from_url",
)
logging.getLogger(__name__).addHandler(NullHandler())
def add_stderr_logger(
    level: int = logging.DEBUG,
) -> logging.StreamHandler[typing.TextIO]:
    """
    Attach a :class:`logging.StreamHandler` (stderr) to this module's logger.

    Handy for quickly turning on debug output while investigating urllib3
    behaviour interactively.

    :param level: logging level to set on the module logger.
    :returns: the handler that was just attached, so callers can remove it later.
    """
    # Resolve the logger here (not at import time) so __name__ is correct
    # even when urllib3 is vendored inside another package.
    module_logger = logging.getLogger(__name__)
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    )
    module_logger.addHandler(stderr_handler)
    module_logger.setLevel(level)
    module_logger.debug("Added a stderr logging handler to logger: %s", __name__)
    return stderr_handler
# ... Clean up: NullHandler was only needed for the addHandler call above.
del NullHandler
# SSHKEYLOGFILE / QUICLOGDIR are debug knobs (key/secret logging for TLS and
# QUIC traffic, per the warning text below). Finding them set outside of a
# development environment is a security hazard, so warn loudly.
if (
    environ.get("SSHKEYLOGFILE", None) is not None
    or environ.get("QUICLOGDIR", None) is not None
):
    warnings.warn(  # Defensive: security warning only, not a feature toggle.
        "urllib3.future detected that development/debug environment variable are set. "
        "If you are unaware of it please audit your environment. "
        "Variables 'SSHKEYLOGFILE' and 'QUICLOGDIR' can only be set in a non-production environment.",
        exceptions.SecurityWarning,
    )
# All warning filters *must* be appended unless you're really certain that they
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
def disable_warnings(category: type[Warning] = exceptions.HTTPWarning) -> None:
    """
    Helper for quickly disabling all urllib3 warnings.

    :param category: warning class to ignore; defaults to
        ``exceptions.HTTPWarning`` so that, per the summary above, every
        urllib3 warning is silenced at once.
    """
    warnings.simplefilter("ignore", category)
# Module-global pool backing the top-level request() helper below. It is
# shared process-wide, so its connections and configuration are shared too.
_DEFAULT_POOL = PoolManager()
def request(
    method: str,
    url: str,
    *,
    body: _TYPE_BODY | None = None,
    fields: _TYPE_FIELDS | None = None,
    headers: typing.Mapping[str, str] | None = None,
    preload_content: bool | None = True,
    decode_content: bool | None = True,
    redirect: bool | None = True,
    retries: Retry | bool | int | None = None,
    timeout: Timeout | float | int | None = 3,
    json: typing.Any | None = None,
) -> HTTPResponse:
    """
    A convenience, top-level request method. It uses a module-global ``PoolManager`` instance.
    Therefore, its side effects could be shared across dependencies relying on it.
    To avoid side effects create a new ``PoolManager`` instance and use it instead.
    The method does not accept low-level ``**urlopen_kw``.

    All keyword arguments are forwarded verbatim to
    ``PoolManager.request``. Notably ``timeout`` defaults to 3 seconds here
    rather than "no timeout".

    :param method: HTTP verb, e.g. ``"GET"`` or ``"POST"``.
    :param url: absolute URL to request.
    :returns: the fully constructed :class:`HTTPResponse`.
    """
    return _DEFAULT_POOL.request(
        method,
        url,
        body=body,
        fields=fields,
        headers=headers,
        preload_content=preload_content,
        decode_content=decode_content,
        redirect=redirect,
        retries=retries,
        timeout=timeout,
        json=json,
    )

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,539 @@
from __future__ import annotations
import io
import json as _json
import sys
import typing
import warnings
from contextlib import asynccontextmanager
from socket import timeout as SocketTimeout
from .._collections import HTTPHeaderDict
from .._typing import _TYPE_BODY
from ..backend._async import AsyncLowLevelResponse
from ..exceptions import (
BaseSSLError,
HTTPError,
IncompleteRead,
ProtocolError,
ReadTimeoutError,
ResponseNotReady,
SSLError,
MustRedialError,
)
from ..response import ContentDecoder, HTTPResponse
from ..util.response import is_fp_closed, BytesQueueBuffer
from ..util.retry import Retry
from .connection import AsyncHTTPConnection
if typing.TYPE_CHECKING:
from email.message import Message
from .._async.connectionpool import AsyncHTTPConnectionPool
from ..contrib.webextensions._async import AsyncExtensionFromHTTP
from ..util._async.traffic_police import AsyncTrafficPolice
class AsyncHTTPResponse(HTTPResponse):
    """
    Async counterpart of :class:`~urllib3.response.HTTPResponse`.

    Mirrors the sync class's constructor and bookkeeping, but exposes the
    body-consuming API (``read``, ``read1``, ``stream``, ``data``, ``json``,
    ``close``, iteration) as coroutines / async generators backed by an
    ``AsyncLowLevelResponse``. A ``police_officer`` (AsyncTrafficPolice), when
    given, tracks which connection/pool this response borrows from.
    """
    def __init__(
        self,
        body: _TYPE_BODY = "",
        headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None,
        status: int = 0,
        version: int = 0,
        reason: str | None = None,
        preload_content: bool = True,
        decode_content: bool = True,
        original_response: AsyncLowLevelResponse | None = None,
        pool: AsyncHTTPConnectionPool | None = None,
        connection: AsyncHTTPConnection | None = None,
        msg: Message | None = None,
        retries: Retry | None = None,
        enforce_content_length: bool = True,
        request_method: str | None = None,
        request_url: str | None = None,
        auto_close: bool = True,
        police_officer: AsyncTrafficPolice[AsyncHTTPConnection] | None = None,
    ) -> None:
        """
        Build an async response wrapper.

        ``body`` may be raw ``str``/``bytes`` (stored directly in ``_body``)
        or any object with a ``read`` attribute (kept in ``_fp`` and
        streamed). ``msg`` is deprecated and ignored beyond a warning.
        """
        # Reuse the caller's HTTPHeaderDict when possible; otherwise copy.
        if isinstance(headers, HTTPHeaderDict):
            self.headers = headers
        else:
            self.headers = HTTPHeaderDict(headers)  # type: ignore[arg-type]
        try:
            self.status = int(status)
        except ValueError:
            self.status = 0  # merely for tests, was supported due to broken httplib.
        self.version = version
        self.reason = reason
        self.decode_content = decode_content
        # Becomes True once _decode() has produced output; guards against
        # mixing decoded and undecoded reads later on.
        self._has_decoded_content = False
        self._request_url: str | None = request_url
        self._retries: Retry | None = None
        self._extension: AsyncExtensionFromHTTP | None = None  # type: ignore[assignment]
        self.retries = retries
        # Detect chunked transfer-encoding from the response headers.
        self.chunked = False
        if "transfer-encoding" in self.headers:
            tr_enc = self.headers.get("transfer-encoding", "").lower()
            # Don't incur the penalty of creating a list and then discarding it
            encodings = (enc.strip() for enc in tr_enc.split(","))
            if "chunked" in encodings:
                self.chunked = True
        self._decoder: ContentDecoder | None = None
        self.enforce_content_length = enforce_content_length
        self.auto_close = auto_close
        self._body = None
        self._fp: AsyncLowLevelResponse | typing.IO[typing.Any] | None = None  # type: ignore[assignment]
        self._original_response = original_response  # type: ignore[assignment]
        self._fp_bytes_read = 0
        if msg is not None:
            warnings.warn(
                "Passing msg=.. is deprecated and no-op in urllib3.future and is scheduled to be removed in a future major.",
                DeprecationWarning,
                stacklevel=2,
            )
        self.msg = msg
        if body and isinstance(body, (str, bytes)):
            self._body = body
        self._pool: AsyncHTTPConnectionPool = pool  # type: ignore[assignment]
        self._connection: AsyncHTTPConnection = connection  # type: ignore[assignment]
        if hasattr(body, "read"):
            self._fp = body  # type: ignore[assignment]
        # Are we using the chunked-style of transfer encoding?
        self.chunk_left: int | None = None
        # Determine length of response
        self._request_method: str | None = request_method
        self.length_remaining: int | None = self._init_length(self._request_method)
        # Used to return the correct amount of bytes for partial read()s
        self._decoded_buffer = BytesQueueBuffer()
        self._police_officer: AsyncTrafficPolice[AsyncHTTPConnection] | None = (
            police_officer  # type: ignore[assignment]
        )
        self._preloaded_content: bool = preload_content
        if self._police_officer is not None:
            # Register this response against its connection so the traffic
            # police can route later reads back to the right socket.
            self._police_officer.memorize(self, self._connection)
            # we can utilize a ConnectionPool without level-0 PoolManager!
            if self._police_officer.parent is not None:
                self._police_officer.parent.memorize(self, self._pool)
    async def readinto(self, b: bytearray) -> int:  # type: ignore[override]
        """Read up to ``len(b)`` bytes into *b*; return the count copied (0 on EOF)."""
        temp = await self.read(len(b))
        if len(temp) == 0:
            return 0
        else:
            b[: len(temp)] = temp
            return len(temp)
    @asynccontextmanager
    async def _error_catcher(self) -> typing.AsyncGenerator[None, None]:  # type: ignore[override]
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On exit, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
                # there is yet no clean way to get at it from this context.
                raise ReadTimeoutError(self._pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                raise ReadTimeoutError(self._pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._original_response:
                    self._original_response.close()
                # Closing the response may not actually be sufficient to close
                # everything, so if we have a hold of the connection close that
                # too.
                if self._connection:
                    await self._connection.close()
            # If we hold the original response but it's closed now, we should
            # return the connection back to the pool.
            # NOTE(review): release_conn() is the sync base-class method —
            # presumably safe to call without awaiting here; confirm.
            if self._original_response and self._original_response.isclosed():
                self.release_conn()
    async def drain_conn(self) -> None:  # type: ignore[override]
        """
        Read and discard any remaining HTTP response data in the response connection.
        Unread data in the HTTPResponse connection blocks the connection from being released back to the pool.
        """
        try:
            await self.read(
                # Do not spend resources decoding the content unless
                # decoding has already been initiated.
                decode_content=self._has_decoded_content,
            )
        except (HTTPError, OSError, BaseSSLError):
            # Best-effort drain: any failure just means the connection will
            # not be reused.
            pass
    @property
    def trailers(self) -> HTTPHeaderDict | None:
        """
        Retrieve post-response (trailing headers) if any.
        This WILL return None if no HTTP Trailer Headers have been received.
        """
        if self._fp is None:
            return None
        if hasattr(self._fp, "trailers"):
            return self._fp.trailers
        return None
    @property
    def extension(self) -> AsyncExtensionFromHTTP | None:  # type: ignore[override]
        """The HTTP extension currently attached via start_extension(), or None."""
        return self._extension
    async def start_extension(self, item: AsyncExtensionFromHTTP) -> None:  # type: ignore[override]
        """
        Attach and start the given HTTP extension on this response.

        :raises OSError: if an extension is already attached.
        :raises ResponseNotReady: if the low-level response lacks a ``_dsa``
            attribute (its data stream; presumably meaning it is not ready —
            confirm against AsyncLowLevelResponse).
        """
        if self._extension is not None:
            raise OSError("extension already plugged in")
        if not hasattr(self._fp, "_dsa"):
            raise ResponseNotReady()
        await item.start(self)
        self._extension = item
    async def json(self) -> typing.Any:
        """
        Parses the body of the HTTP response as JSON.
        To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to the decoder.
        This method can raise either `UnicodeDecodeError` or `json.JSONDecodeError`.
        Read more :ref:`here <json>`.
        """
        data = (await self.data).decode("utf-8")
        return _json.loads(data)
    @property
    async def data(self) -> bytes:  # type: ignore[override]
        """Whole response body; cached after the first full read."""
        # For backwards-compat with earlier urllib3 0.4 and earlier.
        if self._body:
            return self._body  # type: ignore[return-value]
        if self._fp:
            return await self.read(cache_content=True)
        return None  # type: ignore[return-value]
    async def _fp_read(self, amt: int | None = None) -> bytes:  # type: ignore[override]
        """
        Read a response with the thought that reading the number of bytes
        larger than can fit in a 32-bit int at a time via SSL in some
        known cases leads to an overflow error that has to be prevented
        if `amt` or `self.length_remaining` indicate that a problem may
        happen.
        The known cases:
          * 3.8 <= CPython < 3.9.7 because of a bug
            https://github.com/urllib3/urllib3/issues/2513#issuecomment-1152559900.
          * urllib3 injected with pyOpenSSL-backed SSL-support.
          * CPython < 3.10 only when `amt` does not fit 32-bit int.
        """
        assert self._fp
        c_int_max = 2**31 - 1
        if (
            (amt and amt > c_int_max)
            or (self.length_remaining and self.length_remaining > c_int_max)
        ) and sys.version_info < (3, 10):
            # Chunked fallback path: accumulate reads of at most
            # max_chunk_amt into a BytesIO to stay under the 32-bit limit.
            buffer = io.BytesIO()
            # Besides `max_chunk_amt` being a maximum chunk size, it
            # affects memory overhead of reading a response by this
            # method in CPython.
            # `c_int_max` equal to 2 GiB - 1 byte is the actual maximum
            # chunk size that does not lead to an overflow error, but
            # 256 MiB is a compromise.
            max_chunk_amt = 2**28
            while amt is None or amt != 0:
                if amt is not None:
                    chunk_amt = min(amt, max_chunk_amt)
                    amt -= chunk_amt
                else:
                    chunk_amt = max_chunk_amt
                try:
                    if isinstance(self._fp, AsyncLowLevelResponse):
                        data = await self._fp.read(chunk_amt)
                    else:
                        data = self._fp.read(chunk_amt)  # type: ignore[attr-defined]
                except ValueError:  # Defensive: overly protective
                    break  # Defensive: can also be an indicator that read ended, should not happen.
                if not data:
                    break
                buffer.write(data)
                del data  # to reduce peak memory usage by `max_chunk_amt`.
            return buffer.getvalue()
        else:
            # StringIO doesn't like amt=None
            if isinstance(self._fp, AsyncLowLevelResponse):
                return await self._fp.read(amt)
            return self._fp.read(amt) if amt is not None else self._fp.read()  # type: ignore[no-any-return]
    async def _raw_read(  # type: ignore[override]
        self,
        amt: int | None = None,
    ) -> bytes:
        """
        Reads `amt` of bytes from the socket.
        """
        if self._fp is None:
            return None  # type: ignore[return-value]
        fp_closed = getattr(self._fp, "closed", False)
        async with self._error_catcher():
            data = (await self._fp_read(amt)) if not fp_closed else b""
            # Mocking library often use io.BytesIO
            # which does not auto-close when reading data
            # with amt=None.
            is_foreign_fp_unclosed = (
                amt is None and getattr(self._fp, "closed", False) is False
            )
            if (amt is not None and amt != 0 and not data) or is_foreign_fp_unclosed:
                if is_foreign_fp_unclosed:
                    # Account for the bytes here since the branch below
                    # (guarded by `not is_foreign_fp_unclosed`) will not run.
                    self._fp_bytes_read += len(data)
                    if self.length_remaining is not None:
                        self.length_remaining -= len(data)
                # Platform-specific: Buggy versions of Python.
                # Close the connection when no data is returned
                #
                # This is redundant to what httplib/http.client _should_
                # already do. However, versions of python released before
                # December 15, 2012 (http://bugs.python.org/issue16298) do
                # not properly close the connection in all cases. There is
                # no harm in redundantly calling close.
                self._fp.close()
                if (
                    self.enforce_content_length
                    and self.length_remaining is not None
                    and self.length_remaining != 0
                ):
                    # This is an edge case that httplib failed to cover due
                    # to concerns of backward compatibility. We're
                    # addressing it here to make sure IncompleteRead is
                    # raised during streaming, so all calls with incorrect
                    # Content-Length are caught.
                    raise IncompleteRead(self._fp_bytes_read, self.length_remaining)
            if data and not is_foreign_fp_unclosed:
                self._fp_bytes_read += len(data)
                if self.length_remaining is not None:
                    self.length_remaining -= len(data)
        return data
    async def read1(  # type: ignore[override]
        self,
        amt: int | None = None,
        decode_content: bool | None = None,
    ) -> bytes:
        """
        Similar to ``http.client.HTTPResponse.read1`` and documented
        in :meth:`io.BufferedReader.read1`, but with an additional parameter:
        ``decode_content``.
        :param amt:
            How much of the content to read.
        :param decode_content:
            If True, will attempt to decode the body based on the
            'content-encoding' header.
        """
        # amt=-1 means "whatever is available in a single read".
        data = await self.read(
            amt=amt or -1,
            decode_content=decode_content,
        )
        if amt is not None and len(data) > amt:
            # Keep the surplus buffered for the next call.
            self._decoded_buffer.put(data)
            return self._decoded_buffer.get(amt)
        return data
    async def read(  # type: ignore[override]
        self,
        amt: int | None = None,
        decode_content: bool | None = None,
        cache_content: bool = False,
    ) -> bytes:
        """
        Read up to ``amt`` bytes of the (optionally decoded) body.

        :param amt: ``None`` reads everything; a negative value drains what
            is immediately available; a positive value reads exactly that
            many decoded bytes (buffering any surplus).
        :param decode_content: overrides the instance-level setting.
        :param cache_content: store the full body in ``_body`` (full reads only).
        """
        try:
            self._init_decoder()
            if decode_content is None:
                decode_content = self.decode_content
            if amt is not None:
                cache_content = False
                # Serve partial reads from the already-decoded buffer first.
                if amt < 0 and len(self._decoded_buffer):
                    return self._decoded_buffer.get(len(self._decoded_buffer))
                if 0 < amt <= len(self._decoded_buffer):
                    return self._decoded_buffer.get(amt)
            # Borrow the connection from the traffic police while touching
            # the socket, when one supervises this response.
            if self._police_officer is not None:
                async with self._police_officer.borrow(self):
                    data = await self._raw_read(amt)
            else:
                data = await self._raw_read(amt)
            if amt and amt < 0:
                amt = len(data)
            # Decide whether the decompressor must be flushed (end of body).
            flush_decoder = False
            if amt is None:
                flush_decoder = True
            elif amt != 0 and not data:
                flush_decoder = True
            if not data and len(self._decoded_buffer) == 0:
                return data
            if amt is None:
                data = self._decode(data, decode_content, flush_decoder)
                if cache_content:
                    self._body = data
            else:
                # do not waste memory on buffer when not decoding
                if not decode_content:
                    if self._has_decoded_content:
                        raise RuntimeError(
                            "Calling read(decode_content=False) is not supported after "
                            "read(decode_content=True) was called."
                        )
                    return data
                decoded_data = self._decode(data, decode_content, flush_decoder)
                self._decoded_buffer.put(decoded_data)
                # Keep reading until we have amt decoded bytes or hit EOF.
                while len(self._decoded_buffer) < amt and data:
                    # TODO make sure to initially read enough data to get past the headers
                    # For example, the GZ file header takes 10 bytes, we don't want to read
                    # it one byte at a time
                    if self._police_officer is not None:
                        async with self._police_officer.borrow(self):
                            data = await self._raw_read(amt)
                    else:
                        data = await self._raw_read(amt)
                    decoded_data = self._decode(data, decode_content, flush_decoder)
                    self._decoded_buffer.put(decoded_data)
                data = self._decoded_buffer.get(amt)
            return data
        finally:
            # Once the low-level response reached end-of-transmission, stop
            # being tracked by the traffic police...
            if (
                self._fp
                and hasattr(self._fp, "_eot")
                and self._fp._eot
                and self._police_officer is not None
            ):
                # an HTTP extension could be live, we don't want to accidentally kill it!
                if (
                    not hasattr(self._fp, "_dsa")
                    or self._fp._dsa is None
                    or self._fp._dsa.closed is True
                ):
                    self._police_officer.forget(self)
                    self._police_officer = None
    async def stream(  # type: ignore[override]
        self, amt: int | None = 2**16, decode_content: bool | None = None
    ) -> typing.AsyncGenerator[bytes, None]:
        """Yield the body in chunks of up to ``amt`` bytes until exhausted."""
        if self._fp is None:
            return
        while not is_fp_closed(self._fp) or len(self._decoded_buffer) > 0:
            data = await self.read(amt=amt, decode_content=decode_content)
            if data:
                yield data
    async def close(self) -> None:  # type: ignore[override]
        """Close extension (if any), the underlying fp and the connection."""
        if self.extension is not None and self.extension.closed is False:
            await self.extension.close()
        if not self.closed and self._fp:
            self._fp.close()
        if self._connection:
            await self._connection.close()
        if not self.auto_close:
            io.IOBase.close(self)
    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Iterate the decoded body line by line (split on ``b"\\n"``)."""
        buffer: list[bytes] = []
        async for chunk in self.stream(-1, decode_content=True):
            if b"\n" in chunk:
                chunks = chunk.split(b"\n")
                yield b"".join(buffer) + chunks[0] + b"\n"
                for x in chunks[1:-1]:
                    yield x + b"\n"
                if chunks[-1]:
                    buffer = [chunks[-1]]
                else:
                    buffer = []
            else:
                buffer.append(chunk)
        if buffer:
            # Trailing partial line without a final newline.
            yield b"".join(buffer)
    def __del__(self) -> None:
        # Best-effort cleanup at GC time; cannot await, so the connection
        # itself is not closed here (only the fp).
        if not self.closed:
            # NOTE(review): the inner `not self.closed` re-check is redundant
            # with the guard above.
            if not self.closed and self._fp:
                self._fp.close()
            if not self.auto_close:
                io.IOBase.close(self)

View File

@@ -0,0 +1,440 @@
from __future__ import annotations
import typing
from collections import OrderedDict
from enum import Enum, auto
from functools import lru_cache
from threading import RLock
if typing.TYPE_CHECKING:
# We can only import Protocol if TYPE_CHECKING because it's a development
# dependency, and is not available at runtime.
from typing_extensions import Protocol
class HasGettableStringKeys(Protocol):
def keys(self) -> typing.Iterator[str]: ...
def __getitem__(self, key: str) -> str: ...
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
# Key type
_KT = typing.TypeVar("_KT")
# Value type
_VT = typing.TypeVar("_VT")
# Default type
_DT = typing.TypeVar("_DT")
ValidHTTPHeaderSource = typing.Union[
"HTTPHeaderDict",
typing.Mapping[str, str],
typing.Iterable[typing.Tuple[str, str]],
"HasGettableStringKeys",
]
class _Sentinel(Enum):
    # Unique marker meaning "no default was passed" (see getlist()); an enum
    # member cannot collide with any legitimate caller-supplied default.
    not_passed = auto()
@lru_cache(maxsize=64)
def _lower_wrapper(string: str) -> str:
    """Memoized ``str.lower``: header keys repeat constantly, so caching the
    lowered form avoids re-lowering the same key over and over."""
    return string.lower()
def ensure_can_construct_http_header_dict(
    potential: object,
) -> ValidHTTPHeaderSource | None:
    """
    Decide whether *potential* is usable as an ``HTTPHeaderDict`` source.

    Returns the object itself (narrowed for the type checker) when it can be
    fed to the constructor / ``extend``, or ``None`` when it cannot.
    """
    if isinstance(potential, HTTPHeaderDict):
        return potential
    if isinstance(potential, typing.Mapping):
        # Fully validating a Mapping's contents at runtime is expensive, so
        # for typechecking purposes any Mapping is assumed to be str -> str.
        return typing.cast(typing.Mapping[str, str], potential)
    if isinstance(potential, typing.Iterable):
        # Same trade-off: assume any Iterable yields (name, value) pairs.
        return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential)
    if hasattr(potential, "keys") and hasattr(potential, "__getitem__"):
        # Duck-typed: close enough to a mapping for extend() to consume.
        return typing.cast("HasGettableStringKeys", potential)
    return None
class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]):
    """
    Thread-safe LRU mapping bounded at ``maxsize`` entries: once the bound is
    exceeded, the least-recently-used key is discarded. Caution:
    RecentlyUsedContainer is deprecated and scheduled for removal in a next
    major of urllib3.future. It has been replaced by a more suitable
    implementation in ``urllib3.util.traffic_police``.

    :param maxsize:
        Maximum number of recent elements to retain.
    :param dispose_func:
        Optional callback invoked as ``dispose_func(value)`` for every value
        evicted or removed from the container.
    """

    _container: typing.OrderedDict[_KT, _VT]
    _maxsize: int
    dispose_func: typing.Callable[[_VT], None] | None
    lock: RLock

    def __init__(
        self,
        maxsize: int = 10,
        dispose_func: typing.Callable[[_VT], None] | None = None,
    ) -> None:
        super().__init__()
        self.dispose_func = dispose_func
        self._maxsize = maxsize
        self._container = OrderedDict()
        self.lock = RLock()

    def __getitem__(self, key: _KT) -> _VT:
        with self.lock:
            # Pop + re-insert moves the key to the freshest end of the order.
            value = self._container.pop(key)
            self._container[key] = value
        return value

    def __setitem__(self, key: _KT, value: _VT) -> None:
        displaced = None
        with self.lock:
            try:
                # Existing key: pull the old entry out first so the re-insert
                # lands at the fresh end without growing the container.
                displaced = key, self._container.pop(key)
                self._container[key] = value
            except KeyError:
                # New key: insert first, then trim. Inserting before trimming
                # keeps this correct even when self._maxsize is 0.
                self._container[key] = value
                if len(self._container) > self._maxsize:
                    displaced = self._container.popitem(last=False)
        # Run the user callback outside the lock to avoid deadlocks.
        if displaced is not None and self.dispose_func:
            self.dispose_func(displaced[1])

    def __delitem__(self, key: _KT) -> None:
        with self.lock:
            removed = self._container.pop(key)
        if self.dispose_func:
            self.dispose_func(removed)

    def __len__(self) -> int:
        with self.lock:
            return len(self._container)

    def __iter__(self) -> typing.NoReturn:
        raise NotImplementedError(
            "Iteration over this class is unlikely to be threadsafe."
        )

    def clear(self) -> None:
        with self.lock:
            # Snapshot the values, then wipe the mapping while still locked.
            doomed = list(self._container.values())
            self._container.clear()
        if self.dispose_func:
            for value in doomed:
                self.dispose_func(value)

    def keys(self) -> set[_KT]:  # type: ignore[override]
        with self.lock:
            return set(self._container.keys())
class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]):
    """
    Items view over an :class:`HTTPHeaderDict`.

    HTTPHeaderDict addresses values two ways: ``d['X-Header-Name']`` returns
    every stored value joined with ", ", while iterating ``d.items()`` yields
    one ``(name, value)`` pair per stored line — duplicates included (unless
    they were merged via ``add(..., combine=True)``), ordered by first
    insertion:

        >>> d = HTTPHeaderDict({"A": "1", "B": "foo"})
        >>> d.add("A", "2", combine=True)
        >>> d.add("B", "bar")
        >>> list(d.items())
        [
            ('A', '1, 2'),
            ('B', 'foo'),
            ('B', 'bar'),
        ]

    This class keeps the MutableMapping ABC happy while exposing that
    nonstandard per-line iteration.
    """

    _headers: HTTPHeaderDict

    def __init__(self, headers: HTTPHeaderDict) -> None:
        self._headers = headers

    def __len__(self) -> int:
        # Count per-line items, not distinct keys.
        return sum(1 for _ in self._headers.iteritems())

    def __iter__(self) -> typing.Iterator[tuple[str, str]]:
        return self._headers.iteritems()

    def __contains__(self, item: object) -> bool:
        if not (isinstance(item, tuple) and len(item) == 2):
            return False
        name, value = item
        if isinstance(name, str) and isinstance(value, str):
            return self._headers._has_value_for_header(name, value)
        return False
class HTTPHeaderDict(typing.MutableMapping[str, str]):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
when compared case-insensitively.
:param kwargs:
Additional field-value pairs to pass in to ``dict.update``.
A ``dict`` like container for storing HTTP Headers.
Field names are stored and compared case-insensitively in compliance with
RFC 7230. Iteration provides the first case-sensitive key seen for each
case-insensitive pair.
Using ``__setitem__`` syntax overwrites fields that compare equal
case-insensitively in order to maintain ``dict``'s api. For fields that
compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
in a loop.
If multiple fields that are equal case-insensitively are passed to the
constructor or ``.update``, the behavior is undefined and some will be
lost.
>>> headers = HTTPHeaderDict()
>>> headers.add('Set-Cookie', 'foo=bar')
>>> headers.add('set-cookie', 'baz=quxx')
>>> headers['content-length'] = '7'
>>> headers['SET-cookie']
'foo=bar, baz=quxx'
>>> headers['Content-Length']
'7'
"""
_container: typing.MutableMapping[str, list[str]]
def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str):
super().__init__()
self._container = {} # 'dict' is insert-ordered in Python 3.7+
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
else:
self.extend(headers)
if kwargs:
self.extend(kwargs)
def __setitem__(self, key: str, val: str) -> None:
# avoid a bytes/str comparison by decoding before httplib
self._container[_lower_wrapper(key)] = [key, val]
def __getitem__(self, key: str) -> str:
if isinstance(key, bytes):
key = key.decode("latin-1")
val = self._container[_lower_wrapper(key)]
return ", ".join(val[1:])
def __delitem__(self, key: str) -> None:
if isinstance(key, bytes):
key = key.decode("latin-1")
del self._container[_lower_wrapper(key)]
def __contains__(self, key: object) -> bool:
if isinstance(key, bytes):
key = key.decode("latin-1")
if isinstance(key, str):
return _lower_wrapper(key) in self._container
return False
def setdefault(self, key: str, default: str = "") -> str:
return super().setdefault(key, default)
def __eq__(self, other: object) -> bool:
maybe_constructable = ensure_can_construct_http_header_dict(other)
if maybe_constructable is None:
return False
else:
other_as_http_header_dict = type(self)(maybe_constructable)
return {_lower_wrapper(k): v for k, v in self.itermerged()} == {
_lower_wrapper(k): v for k, v in other_as_http_header_dict.itermerged()
}
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __len__(self) -> int:
return len(self._container)
def __iter__(self) -> typing.Iterator[str]:
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]
def discard(self, key: str) -> None:
try:
del self[key]
except KeyError:
pass
def add(self, key: str, val: str, *, combine: bool = False) -> None:
"""Adds a (name, value) pair, doesn't overwrite the value if it already
exists.
If this is called with combine=True, instead of adding a new header value
as a distinct item during iteration, this will instead append the value to
any existing header value with a comma. If no existing header value exists
for the key, then the value will simply be added, ignoring the combine parameter.
>>> headers = HTTPHeaderDict(foo='bar')
>>> headers.add('Foo', 'baz')
>>> headers['foo']
'bar, baz'
>>> list(headers.items())
[('foo', 'bar'), ('foo', 'baz')]
>>> headers.add('foo', 'quz', combine=True)
>>> list(headers.items())
[('foo', 'bar, baz, quz')]
"""
key_lower = _lower_wrapper(key)
new_vals = [key, val]
# Keep the common case aka no item present as fast as possible
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
# if there are values here, then there is at least the initial
# key/value pair
if combine:
vals[-1] = vals[-1] + ", " + val
else:
vals.append(val)
def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None:
"""Generic import function for any type of header-like object.
Adapted version of MutableMapping.update in order to insert items
with self.add instead of self.__setitem__
"""
if len(args) > 1:
raise TypeError(
f"extend() takes at most 1 positional arguments ({len(args)} given)"
)
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
elif isinstance(other, typing.Mapping):
for key, val in other.items():
self.add(key, val)
elif isinstance(other, typing.Iterable):
for key, value in other:
self.add(key, value)
elif hasattr(other, "keys") and hasattr(other, "__getitem__"):
# THIS IS NOT A TYPESAFE BRANCH
# In this branch, the object has a `keys` attr but is not a Mapping or any of
# the other types indicated in the method signature. We do some stuff with
# it as though it partially implements the Mapping interface, but we're not
# doing that stuff safely AT ALL.
for key in other.keys():
self.add(key, other[key])
for key, value in kwargs.items():
self.add(key, value)
@typing.overload
def getlist(self, key: str) -> list[str]: ...
@typing.overload
def getlist(self, key: str, default: _DT) -> list[str] | _DT: ...
def getlist(
    self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed
) -> list[str] | _DT:
    """Return every value stored under *key*.

    Missing keys yield an empty list, unless an explicit *default*
    was supplied, in which case that default is returned unchanged.
    """
    if isinstance(key, bytes):
        key = key.decode("latin-1")
    bucket = self._container.get(_lower_wrapper(key))
    if bucket is None:
        # Key absent: caller's default wins when one was given.
        return [] if default is _Sentinel.not_passed else default
    # bucket[0] holds the originally-cased key; values follow it.
    return bucket[1:]

# Backwards compatibility for httplib
getheaders = getlist
getallmatchingheaders = getlist
iget = getlist

# Backwards compatibility for http.cookiejar
get_all = getlist
def __repr__(self) -> str:
    """Debug representation: concrete class name plus the merged mapping."""
    merged = dict(self.itermerged())
    return f"{type(self).__name__}({merged})"
def _copy_from(self, other: HTTPHeaderDict) -> None:
    # Rebuild each bucket as [original-cased key, *values] so this instance
    # shares no list objects with ``other``.
    for name in other:
        self._container[_lower_wrapper(name)] = [name, *other.getlist(name)]
def copy(self) -> HTTPHeaderDict:
    """Return an independent duplicate of this header dictionary."""
    duplicate = type(self)()
    duplicate._copy_from(self)
    return duplicate
def iteritems(self) -> typing.Iterator[tuple[str, str]]:
    """Yield one ``(key, value)`` pair per header line, duplicates included."""
    for name in self:
        bucket = self._container[_lower_wrapper(name)]
        original_key = bucket[0]
        for value in bucket[1:]:
            yield original_key, value
def itermerged(self) -> typing.Iterator[tuple[str, str]]:
    """Yield one ``(key, comma_joined_values)`` pair per distinct header."""
    for name in self:
        bucket = self._container[_lower_wrapper(name)]
        yield bucket[0], ", ".join(bucket[1:])
def items(self) -> HTTPHeaderDictItemView: # type: ignore[override]
    # Returns HTTPHeaderDictItemView (defined elsewhere in this module)
    # instead of a plain dict view — hence the override ignore on the
    # return type.
    return HTTPHeaderDictItemView(self)
def _has_value_for_header(self, header_name: str, potential_value: str) -> bool:
    # Membership test against the stored values only; bucket[0] is the
    # originally-cased key and must not match.
    if header_name not in self:
        return False
    return potential_value in self._container[_lower_wrapper(header_name)][1:]

View File

@@ -0,0 +1,255 @@
from __future__ import annotations
import socket
import typing
from enum import IntEnum
class HTTPStatus(IntEnum):
    """HTTP status codes and reason phrases
    Status codes from the following RFCs are all observed:
    * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
    * RFC 6585: Additional HTTP Status Codes
    * RFC 3229: Delta encoding in HTTP
    * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
    * RFC 5842: Binding Extensions to WebDAV
    * RFC 7238: Permanent Redirect
    * RFC 2295: Transparent Content Negotiation in HTTP
    * RFC 2774: An HTTP Extension Framework
    * RFC 7725: An HTTP Status Code to Report Legal Obstacles
    * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2)
    * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0)
    * RFC 8297: An HTTP Status Code for Indicating Hints
    * RFC 8470: Using Early Data in HTTP
    """

    # Extra per-member attributes attached in __new__ below.
    phrase: str  # short reason phrase, e.g. "Not Found"
    description: str  # longer human-readable description (may be empty)
    standard: bool  # False only for members declared with is_standard=False

    def __new__(
        cls, value: int, phrase: str, description: str = "", is_standard: bool = True
    ) -> HTTPStatus:
        # Build the int-based enum member, then attach the reason phrase and
        # description so e.g. ``HTTPStatus.OK.phrase == "OK"``.
        obj = int.__new__(cls, value)
        obj._value_ = value
        obj.phrase = phrase
        obj.description = description
        obj.standard = is_standard
        return obj

    # informational
    CONTINUE = 100, "Continue", "Request received, please continue"
    SWITCHING_PROTOCOLS = (
        101,
        "Switching Protocols",
        "Switching to new protocol; obey Upgrade header",
    )
    PROCESSING = 102, "Processing"
    EARLY_HINTS = 103, "Early Hints"

    # success
    OK = 200, "OK", "Request fulfilled, document follows"
    CREATED = 201, "Created", "Document created, URL follows"
    ACCEPTED = (202, "Accepted", "Request accepted, processing continues off-line")
    NON_AUTHORITATIVE_INFORMATION = (
        203,
        "Non-Authoritative Information",
        "Request fulfilled from cache",
    )
    NO_CONTENT = 204, "No Content", "Request fulfilled, nothing follows"
    RESET_CONTENT = 205, "Reset Content", "Clear input form for further input"
    PARTIAL_CONTENT = 206, "Partial Content", "Partial content follows"
    MULTI_STATUS = 207, "Multi-Status"
    ALREADY_REPORTED = 208, "Already Reported"
    IM_USED = 226, "IM Used"

    # redirection
    MULTIPLE_CHOICES = (
        300,
        "Multiple Choices",
        "Object has several resources -- see URI list",
    )
    MOVED_PERMANENTLY = (
        301,
        "Moved Permanently",
        "Object moved permanently -- see URI list",
    )
    FOUND = 302, "Found", "Object moved temporarily -- see URI list"
    SEE_OTHER = 303, "See Other", "Object moved -- see Method and URL list"
    NOT_MODIFIED = (304, "Not Modified", "Document has not changed since given time")
    USE_PROXY = (
        305,
        "Use Proxy",
        "You must use proxy specified in Location to access this resource",
    )
    TEMPORARY_REDIRECT = (
        307,
        "Temporary Redirect",
        "Object moved temporarily -- see URI list",
    )
    PERMANENT_REDIRECT = (
        308,
        "Permanent Redirect",
        "Object moved permanently -- see URI list",
    )

    # client error
    BAD_REQUEST = (400, "Bad Request", "Bad request syntax or unsupported method")
    UNAUTHORIZED = (401, "Unauthorized", "No permission -- see authorization schemes")
    PAYMENT_REQUIRED = (402, "Payment Required", "No payment -- see charging schemes")
    FORBIDDEN = (403, "Forbidden", "Request forbidden -- authorization will not help")
    NOT_FOUND = (404, "Not Found", "Nothing matches the given URI")
    METHOD_NOT_ALLOWED = (
        405,
        "Method Not Allowed",
        "Specified method is invalid for this resource",
    )
    NOT_ACCEPTABLE = (406, "Not Acceptable", "URI not available in preferred format")
    PROXY_AUTHENTICATION_REQUIRED = (
        407,
        "Proxy Authentication Required",
        "You must authenticate with this proxy before proceeding",
    )
    REQUEST_TIMEOUT = (408, "Request Timeout", "Request timed out; try again later")
    CONFLICT = 409, "Conflict", "Request conflict"
    GONE = (410, "Gone", "URI no longer exists and has been permanently removed")
    LENGTH_REQUIRED = (411, "Length Required", "Client must specify Content-Length")
    PRECONDITION_FAILED = (
        412,
        "Precondition Failed",
        "Precondition in headers is false",
    )
    REQUEST_ENTITY_TOO_LARGE = (413, "Request Entity Too Large", "Entity is too large")
    REQUEST_URI_TOO_LONG = (414, "Request-URI Too Long", "URI is too long")
    UNSUPPORTED_MEDIA_TYPE = (
        415,
        "Unsupported Media Type",
        "Entity body in unsupported format",
    )
    REQUESTED_RANGE_NOT_SATISFIABLE = (
        416,
        "Requested Range Not Satisfiable",
        "Cannot satisfy request range",
    )
    EXPECTATION_FAILED = (
        417,
        "Expectation Failed",
        "Expect condition could not be satisfied",
    )
    IM_A_TEAPOT = (
        418,
        "I'm a Teapot",
        "Server refuses to brew coffee because it is a teapot.",
    )
    MISDIRECTED_REQUEST = (
        421,
        "Misdirected Request",
        "Server is not able to produce a response",
    )
    UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity"
    LOCKED = 423, "Locked"
    FAILED_DEPENDENCY = 424, "Failed Dependency"
    TOO_EARLY = 425, "Too Early"
    UPGRADE_REQUIRED = 426, "Upgrade Required"
    PRECONDITION_REQUIRED = (
        428,
        "Precondition Required",
        "The origin server requires the request to be conditional",
    )
    TOO_MANY_REQUESTS = (
        429,
        "Too Many Requests",
        "The user has sent too many requests in "
        'a given amount of time ("rate limiting")',
    )
    REQUEST_HEADER_FIELDS_TOO_LARGE = (
        431,
        "Request Header Fields Too Large",
        "The server is unwilling to process the request because its header "
        "fields are too large",
    )
    UNAVAILABLE_FOR_LEGAL_REASONS = (
        451,
        "Unavailable For Legal Reasons",
        "The server is denying access to the "
        "resource as a consequence of a legal demand",
    )

    # server errors
    INTERNAL_SERVER_ERROR = (
        500,
        "Internal Server Error",
        "Server got itself in trouble",
    )
    NOT_IMPLEMENTED = (501, "Not Implemented", "Server does not support this operation")
    BAD_GATEWAY = (502, "Bad Gateway", "Invalid responses from another server/proxy")
    SERVICE_UNAVAILABLE = (
        503,
        "Service Unavailable",
        "The server cannot process the request due to a high load",
    )
    GATEWAY_TIMEOUT = (
        504,
        "Gateway Timeout",
        "The gateway server did not receive a timely response",
    )
    HTTP_VERSION_NOT_SUPPORTED = (
        505,
        "HTTP Version Not Supported",
        "Cannot fulfill request",
    )
    VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates"
    INSUFFICIENT_STORAGE = 507, "Insufficient Storage"
    LOOP_DETECTED = 508, "Loop Detected"
    NOT_EXTENDED = 510, "Not Extended"
    NETWORK_AUTHENTICATION_REQUIRED = (
        511,
        "Network Authentication Required",
        "The client needs to authenticate to gain network access",
    )
# Backwards-compatibility shim: map every status member (itself an int)
# to its reason phrase, mirroring the shape of http.client-era tables.
responses: typing.Mapping[int, str] = {
    member: member.phrase for member in HTTPStatus
}
# Default value for `blocksize` - a new parameter introduced to
# http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7
# The maximum TCP packet size is 65535 octets. But OSes can have a recv buffer greater than 65k, so
# passing the highest value does improve responsiveness.
# udp usually is set to 212992
# and tcp usually is set to 131072 (16384 * 8) or (65535 * 2)
try:
    # Dynamically retrieve the kernel rcvbuf size so reads match what the
    # OS can hand us in a single call. Each probe socket is closed in a
    # ``finally`` so it cannot leak when getsockopt itself fails.
    stream = socket.socket(type=socket.SOCK_STREAM)
    try:
        DEFAULT_BLOCKSIZE: int = stream.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
    finally:
        stream.close()
    dgram = socket.socket(type=socket.SOCK_DGRAM)
    try:
        UDP_DEFAULT_BLOCKSIZE: int = dgram.getsockopt(
            socket.SOL_SOCKET, socket.SO_RCVBUF
        )
    finally:
        dgram.close()
except OSError:
    # Conservative fallbacks matching the common defaults noted above.
    DEFAULT_BLOCKSIZE = 131072
    UDP_DEFAULT_BLOCKSIZE = 212992
TCP_DEFAULT_BLOCKSIZE: int = DEFAULT_BLOCKSIZE
# Linux UDP socket option numbers, hard-coded because the stdlib ``socket``
# module does not always expose them. Names/values appear to match the
# kernel's UDP_GRO (104) and UDP_SEGMENT (103) options — TODO confirm
# against <linux/udp.h> for the supported kernel range.
UDP_LINUX_GRO: int = 104
UDP_LINUX_SEGMENT: int = 103

# Mozilla TLS recommendations for ciphers
# General-purpose servers with a variety of clients, recommended for almost all systems.
MOZ_INTERMEDIATE_CIPHERS: str = "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305"

# Tunables for the background watcher / keep-alive machinery. Presumably
# expressed in seconds, with the MINIMAL_* values acting as lower bounds on
# user-supplied settings — verify against the consumers of these constants.
DEFAULT_BACKGROUND_WATCH_WINDOW: float = 5.0
MINIMAL_BACKGROUND_WATCH_WINDOW: float = 0.05
DEFAULT_KEEPALIVE_DELAY: float = 3600.0
DEFAULT_KEEPALIVE_IDLE_WINDOW: float = 60.0
MINIMAL_KEEPALIVE_IDLE_WINDOW: float = 1.0

View File

@@ -0,0 +1,642 @@
from __future__ import annotations
import json as _json
import typing
from urllib.parse import urlencode
from ._async.response import AsyncHTTPResponse
from ._collections import HTTPHeaderDict
from ._typing import _TYPE_ASYNC_BODY, _TYPE_BODY, _TYPE_ENCODE_URL_FIELDS, _TYPE_FIELDS
from .filepost import encode_multipart_formdata
from .response import HTTPResponse
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from .backend import ResponsePromise
__all__ = ["RequestMethods", "AsyncRequestMethods"]
class RequestMethods:
    """
    Convenience mixin for classes who implement a :meth:`urlopen` method, such
    as :class:`urllib3.HTTPConnectionPool` and
    :class:`urllib3.PoolManager`.
    Provides behavior for making common types of HTTP request methods and
    decides which type of request field encoding to use.
    Specifically,
    :meth:`.request_encode_url` is for sending requests whose fields are
    encoded in the URL (such as GET, HEAD, DELETE).
    :meth:`.request_encode_body` is for sending requests whose fields are
    encoded in the *body* of the request using multipart or www-form-urlencoded
    (such as for POST, PUT, PATCH).
    :meth:`.request` is for making any kind of request, it will look up the
    appropriate encoding format and use one of the above two methods to make
    the request.
    Initializer parameters:
    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.
    """

    # Verbs whose ``fields`` go into the URL query string rather than the body.
    _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}

    def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
        # Note: a falsy (e.g. empty) mapping is replaced by a fresh dict.
        self.headers = headers or {}

    # Typing-only overloads: ``multiplexed=True`` yields a ResponsePromise to
    # be resolved later; otherwise the call returns a concrete HTTPResponse.
    @typing.overload
    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> HTTPResponse: ...
    @typing.overload
    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...
    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """Abstract transport hook; concrete pools/managers must override."""
        raise NotImplementedError(
            "Classes extending RequestMethods must implement "
            "their own ``urlopen`` method."
        )

    @typing.overload
    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...
    @typing.overload
    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        json: typing.Any | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the appropriate encoding of
        ``fields`` based on the ``method`` used.
        This is a convenience method that requires the least amount of manual
        effort. It can be used in most situations, while still having the
        option to drop down to more specific methods when necessary, such as
        :meth:`request_encode_url`, :meth:`request_encode_body`,
        or even the lowest level :meth:`urlopen`.
        """
        method = method.upper()
        # ``body`` and ``json`` are mutually exclusive ways to supply a payload.
        if json is not None and body is not None:
            raise TypeError(
                "request got values for both 'body' and 'json' parameters which are mutually exclusive"
            )
        if json is not None:
            if headers is None:
                headers = self.headers.copy() # type: ignore
            # Only set Content-Type when the caller did not provide one
            # (case-insensitive check). NOTE(review): when ``headers`` was
            # passed in, this writes into the caller's mapping — confirm
            # this mutation is acceptable.
            if "content-type" not in map(str.lower, headers.keys()):
                headers["Content-Type"] = "application/json" # type: ignore
            # Compact separators; UTF-8 bytes with non-ASCII preserved.
            body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
                "utf-8"
            )
        if body is not None:
            urlopen_kw["body"] = body
        if method in self._encode_url_methods:
            return self.request_encode_url(
                method,
                url,
                fields=fields, # type: ignore[arg-type]
                headers=headers,
                **urlopen_kw,
            )
        else:
            return self.request_encode_body( # type: ignore[no-any-return]
                method,
                url,
                fields=fields,
                headers=headers,
                **urlopen_kw,
            )

    @typing.overload
    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...
    @typing.overload
    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the url. This is useful for request methods like GET, HEAD, DELETE, etc.
        """
        if headers is None:
            headers = self.headers
        extra_kw: dict[str, typing.Any] = {"headers": headers}
        extra_kw.update(urlopen_kw)
        # Fields become a urlencoded query string appended to the URL.
        if fields:
            url += "?" + urlencode(fields)
        return self.urlopen(method, url, **extra_kw) # type: ignore[no-any-return]

    @typing.overload
    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...
    @typing.overload
    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the body. This is useful for request methods like POST, PUT, PATCH, etc.
        When ``encode_multipart=True`` (default), then
        :func:`urllib3.encode_multipart_formdata` is used to encode
        the payload with the appropriate content type. Otherwise
        :func:`urllib.parse.urlencode` is used with the
        'application/x-www-form-urlencoded' content type.
        Multipart encoding must be used when posting files, and it's reasonably
        safe to use it in other times too. However, it may break request
        signing, such as with OAuth.
        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
        the MIME type is optional. For example::
            fields = {
                'foo': 'bar',
                'fakefile': ('foofile.txt', 'contents of foofile'),
                'realfile': ('barfile.txt', open('realfile').read()),
                'typedfile': ('bazfile.bin', open('bazfile').read(),
                              'image/jpeg'),
                'nonamefile': 'contents of nonamefile field',
            }
        When uploading a file, providing a filename (the first parameter of the
        tuple) is optional but recommended to best mimic behavior of browsers.
        Note that if ``headers`` are supplied, the 'Content-Type' header will
        be overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request. The random boundary
        string can be explicitly set with the ``multipart_boundary`` parameter.
        """
        if headers is None:
            headers = self.headers
        # Wrap in HTTPHeaderDict so setdefault below is case-insensitive.
        extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
        body: bytes | str
        if fields:
            if "body" in urlopen_kw:
                raise TypeError(
                    "request got values for both 'fields' and 'body', can only specify one."
                )
            if encode_multipart:
                body, content_type = encode_multipart_formdata(
                    fields, boundary=multipart_boundary
                )
            else:
                body, content_type = (
                    urlencode(fields), # type: ignore[arg-type]
                    "application/x-www-form-urlencoded",
                )
            extra_kw["body"] = body
            # NOTE(review): setdefault means a caller-supplied Content-Type is
            # preserved, which contradicts the docstring's "will be
            # overwritten" claim — confirm which behavior is intended.
            extra_kw["headers"].setdefault("Content-Type", content_type)
        extra_kw.update(urlopen_kw)
        return self.urlopen(method, url, **extra_kw) # type: ignore[no-any-return]
class AsyncRequestMethods:
    """
    Convenience mixin for classes who implement a :meth:`urlopen` method, such
    as :class:`urllib3.AsyncHTTPConnectionPool` and
    :class:`urllib3.AsyncPoolManager`.
    Provides behavior for making common types of HTTP request methods and
    decides which type of request field encoding to use.
    Specifically,
    :meth:`.request_encode_url` is for sending requests whose fields are
    encoded in the URL (such as GET, HEAD, DELETE).
    :meth:`.request_encode_body` is for sending requests whose fields are
    encoded in the *body* of the request using multipart or www-form-urlencoded
    (such as for POST, PUT, PATCH).
    :meth:`.request` is for making any kind of request, it will look up the
    appropriate encoding format and use one of the above two methods to make
    the request.
    Initializer parameters:
    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.
    """

    # Verbs whose ``fields`` go into the URL query string rather than the body.
    _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}

    def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
        # Note: a falsy (e.g. empty) mapping is replaced by a fresh dict.
        self.headers = headers or {}

    # Typing-only overloads: ``multiplexed=True`` yields a ResponsePromise to
    # be resolved later; otherwise the call returns an AsyncHTTPResponse.
    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse: ...
    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...
    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """Abstract transport hook; concrete async pools/managers must override."""
        raise NotImplementedError(
            "Classes extending RequestMethods must implement "
            "their own ``urlopen`` method."
        )

    @typing.overload
    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...
    @typing.overload
    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        json: typing.Any | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the appropriate encoding of
        ``fields`` based on the ``method`` used.
        This is a convenience method that requires the least amount of manual
        effort. It can be used in most situations, while still having the
        option to drop down to more specific methods when necessary, such as
        :meth:`request_encode_url`, :meth:`request_encode_body`,
        or even the lowest level :meth:`urlopen`.
        """
        method = method.upper()
        # ``body`` and ``json`` are mutually exclusive ways to supply a payload.
        if json is not None and body is not None:
            raise TypeError(
                "request got values for both 'body' and 'json' parameters which are mutually exclusive"
            )
        if json is not None:
            if headers is None:
                headers = self.headers.copy() # type: ignore
            # Only set Content-Type when the caller did not provide one
            # (case-insensitive check). NOTE(review): when ``headers`` was
            # passed in, this writes into the caller's mapping — confirm
            # this mutation is acceptable.
            if "content-type" not in map(str.lower, headers.keys()):
                headers["Content-Type"] = "application/json" # type: ignore
            # Compact separators; UTF-8 bytes with non-ASCII preserved.
            body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
                "utf-8"
            )
        if body is not None:
            urlopen_kw["body"] = body
        if method in self._encode_url_methods:
            return await self.request_encode_url(
                method,
                url,
                fields=fields, # type: ignore[arg-type]
                headers=headers,
                **urlopen_kw,
            )
        else:
            return await self.request_encode_body( # type: ignore[no-any-return]
                method,
                url,
                fields=fields,
                headers=headers,
                **urlopen_kw,
            )

    @typing.overload
    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...
    @typing.overload
    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the url. This is useful for request methods like GET, HEAD, DELETE, etc.
        """
        if headers is None:
            headers = self.headers
        extra_kw: dict[str, typing.Any] = {"headers": headers}
        extra_kw.update(urlopen_kw)
        # Fields become a urlencoded query string appended to the URL.
        if fields:
            url += "?" + urlencode(fields)
        return await self.urlopen(method, url, **extra_kw) # type: ignore[no-any-return]

    @typing.overload
    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...
    @typing.overload
    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...
    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the body. This is useful for request methods like POST, PUT, PATCH, etc.
        When ``encode_multipart=True`` (default), then
        :func:`urllib3.encode_multipart_formdata` is used to encode
        the payload with the appropriate content type. Otherwise
        :func:`urllib.parse.urlencode` is used with the
        'application/x-www-form-urlencoded' content type.
        Multipart encoding must be used when posting files, and it's reasonably
        safe to use it in other times too. However, it may break request
        signing, such as with OAuth.
        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
        the MIME type is optional. For example::
            fields = {
                'foo': 'bar',
                'fakefile': ('foofile.txt', 'contents of foofile'),
                'realfile': ('barfile.txt', open('realfile').read()),
                'typedfile': ('bazfile.bin', open('bazfile').read(),
                              'image/jpeg'),
                'nonamefile': 'contents of nonamefile field',
            }
        When uploading a file, providing a filename (the first parameter of the
        tuple) is optional but recommended to best mimic behavior of browsers.
        Note that if ``headers`` are supplied, the 'Content-Type' header will
        be overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request. The random boundary
        string can be explicitly set with the ``multipart_boundary`` parameter.
        """
        if headers is None:
            headers = self.headers
        # Wrap in HTTPHeaderDict so setdefault below is case-insensitive.
        extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
        body: bytes | str
        if fields:
            if "body" in urlopen_kw:
                raise TypeError(
                    "request got values for both 'fields' and 'body', can only specify one."
                )
            if encode_multipart:
                body, content_type = encode_multipart_formdata(
                    fields, boundary=multipart_boundary
                )
            else:
                body, content_type = (
                    urlencode(fields), # type: ignore[arg-type]
                    "application/x-www-form-urlencoded",
                )
            extra_kw["body"] = body
            # NOTE(review): setdefault means a caller-supplied Content-Type is
            # preserved, which contradicts the docstring's "will be
            # overwritten" claim — confirm which behavior is intended.
            extra_kw["headers"].setdefault("Content-Type", content_type)
        extra_kw.update(urlopen_kw)
        return await self.urlopen(method, url, **extra_kw) # type: ignore[no-any-return]

View File

@@ -0,0 +1,94 @@
from __future__ import annotations
import typing
from enum import Enum
from .backend import LowLevelResponse
from .backend._async import AsyncLowLevelResponse
from .fields import RequestField
from .util.request import _TYPE_FAILEDTELL
from .util.timeout import _TYPE_DEFAULT, Timeout
if typing.TYPE_CHECKING:
import ssl
from typing_extensions import Literal, TypedDict
class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
    # Subset of the dict returned by ``ssl.SSLSocket.getpeercert()``;
    # every key is optional (``total=False``).
    subjectAltName: tuple[tuple[str, str], ...]
    subject: tuple[tuple[tuple[str, str], ...], ...]
    serialNumber: str
# Any value accepted as an outgoing request body.
_TYPE_BODY: typing.TypeAlias = typing.Union[
    bytes,
    typing.IO[typing.Any],
    typing.Iterable[bytes],
    typing.Iterable[str],
    str,
    LowLevelResponse,
    AsyncLowLevelResponse,
]
# Async-only body sources (async iterables of text or bytes).
_TYPE_ASYNC_BODY: typing.TypeAlias = typing.Union[
    typing.AsyncIterable[bytes],
    typing.AsyncIterable[str],
]
# A single form-field value.
_TYPE_FIELD_VALUE: typing.TypeAlias = typing.Union[str, bytes]
# A field value, optionally wrapped as (filename, value) or
# (filename, value, mime_type).
_TYPE_FIELD_VALUE_TUPLE: typing.TypeAlias = typing.Union[
    _TYPE_FIELD_VALUE,
    typing.Tuple[str, _TYPE_FIELD_VALUE],
    typing.Tuple[str, _TYPE_FIELD_VALUE, str],
]
_TYPE_FIELDS_SEQUENCE: typing.TypeAlias = typing.Sequence[
    typing.Union[typing.Tuple[str, _TYPE_FIELD_VALUE_TUPLE], RequestField]
]
# Fields may be given as a sequence of pairs/RequestField or as a mapping.
_TYPE_FIELDS: typing.TypeAlias = typing.Union[
    _TYPE_FIELDS_SEQUENCE,
    typing.Mapping[str, _TYPE_FIELD_VALUE_TUPLE],
]
# Fields destined for the URL query string (str/bytes values only).
_TYPE_ENCODE_URL_FIELDS: typing.TypeAlias = typing.Union[
    typing.Sequence[typing.Tuple[str, typing.Union[str, bytes]]],
    typing.Mapping[str, typing.Union[str, bytes]],
]
# setsockopt-style tuples: (level, optname, value) triples, or 4-tuples
# whose trailing str meaning depends on the consumer — TODO confirm.
_TYPE_SOCKET_OPTIONS: typing.TypeAlias = typing.Sequence[
    typing.Union[
        typing.Tuple[int, int, int],
        typing.Tuple[int, int, int, str],
    ]
]
# Shape returned by ``__reduce__`` implementations: (callable, args).
_TYPE_REDUCE_RESULT: typing.TypeAlias = typing.Tuple[
    typing.Callable[..., object], typing.Tuple[object, ...]
]
# Public timeout argument vs the normalized internal form.
_TYPE_TIMEOUT: typing.TypeAlias = typing.Union[float, _TYPE_DEFAULT, Timeout, None]
_TYPE_TIMEOUT_INTERNAL: typing.TypeAlias = typing.Union[float, _TYPE_DEFAULT, None]
# Peer certificate: decoded dict, DER bytes, or absent.
_TYPE_PEER_CERT_RET: typing.TypeAlias = typing.Union[
    "_TYPE_PEER_CERT_RET_DICT", bytes, None
]
_TYPE_BODY_POSITION: typing.TypeAlias = typing.Union[int, _TYPE_FAILEDTELL]
try:
    from typing import TypedDict

    # Keyword arguments forwarded to the SOCKS proxy connection layer.
    # NOTE(review): proxy_port is typed as str — confirm callers pass the
    # port pre-formatted rather than as an int.
    class _TYPE_SOCKS_OPTIONS(TypedDict):
        socks_version: int | Enum
        proxy_host: str | None
        proxy_port: str | None
        username: str | None
        password: str | None
        rdns: bool

except ImportError: # Python 3.7
    _TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any] # type: ignore[misc, assignment]


# Immutable per-proxy TLS/verification configuration.
class ProxyConfig(typing.NamedTuple):
    ssl_context: ssl.SSLContext | None
    use_forwarding_for_https: bool
    assert_hostname: None | str | Literal[False]
    assert_fingerprint: str | None

View File

@@ -0,0 +1,4 @@
# This file is protected via CODEOWNERS
from __future__ import annotations

# Single source of truth for the package version (PEP 440 string).
__version__ = "2.17.902"

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._base import (
BaseBackend,
ConnectionInfo,
HttpVersion,
LowLevelResponse,
QuicPreemptiveCacheType,
ResponsePromise,
)
from .hface import HfaceBackend
__all__ = (
"BaseBackend",
"HfaceBackend",
"HttpVersion",
"QuicPreemptiveCacheType",
"LowLevelResponse",
"ConnectionInfo",
"ResponsePromise",
)

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from ._base import AsyncBaseBackend, AsyncLowLevelResponse
from .hface import AsyncHfaceBackend
__all__ = (
"AsyncBaseBackend",
"AsyncLowLevelResponse",
"AsyncHfaceBackend",
)

View File

@@ -0,0 +1,330 @@
from __future__ import annotations
import typing
from ..._collections import HTTPHeaderDict
from ...contrib.ssa import AsyncSocket, SSLAsyncSocket
from .._base import BaseBackend, ResponsePromise
from ...util.response import BytesQueueBuffer
class AsyncDirectStreamAccess:
    """Socket/io-like async adapter exposing read/write access to a single
    stream (identified by ``stream_id``) through two injected coroutine
    callbacks. Either direction may be absent (``None``), making the object
    read-only, write-only, or fully closed."""

    def __init__(
        self,
        stream_id: int,
        read: typing.Callable[
            [int | None, int | None, bool, bool],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None = None,
        write: typing.Callable[[bytes, int, bool], typing.Awaitable[None]]
        | None = None,
    ) -> None:
        # Callback shapes inferred from the call sites below — confirm
        # against the backend implementation:
        #   read(amt, stream_id, respect_amt, abort) -> (data, end_of_stream, trailers)
        #   write(data, stream_id, end_stream) -> None
        self._stream_id = stream_id
        self._read = read
        self._write = write

    @property
    def closed(self) -> bool:
        # Closed once both directions have been released (set to None).
        return self._read is None and self._write is None

    async def readinto(self, b: bytearray) -> int:
        """Read up to ``len(b)`` bytes into *b*; return the count (0 at EOF)."""
        if self._read is None:
            raise OSError("read operation on a closed stream")
        temp = await self.recv(len(b))
        if len(temp) == 0:
            return 0
        else:
            b[: len(temp)] = temp
            return len(temp)

    def readable(self) -> bool:
        return self._read is not None

    def writable(self) -> bool:
        return self._write is not None

    def seekable(self) -> bool:
        # Streams are strictly sequential.
        return False

    def fileno(self) -> int:
        # No OS-level descriptor backs this object.
        return -1

    def name(self) -> int:
        # NOTE(review): io objects usually expose ``name`` as an attribute,
        # not a method; kept as a method here for existing callers.
        return -1

    async def recv(self, __bufsize: int, __flags: int = 0) -> bytes:
        """``socket.recv`` lookalike; ``__flags`` is accepted but unused."""
        data, _, _ = await self.recv_extended(__bufsize)
        return data

    async def recv_extended(
        self, __bufsize: int | None
    ) -> tuple[bytes, bool, HTTPHeaderDict | None]:
        """Read once; return ``(data, end_of_stream, trailers)``. A ``None``
        size requests whatever is available."""
        if self._read is None:
            raise OSError("stream closed error")
        data, eot, trailers = await self._read(
            __bufsize,
            self._stream_id,
            __bufsize is not None,
            False,
        )
        if eot:
            # End of transmission: drop the reader so later reads fail fast.
            self._read = None
        return data, eot, trailers

    async def sendall(self, __data: bytes, __flags: int = 0) -> None:
        """``socket.sendall`` lookalike; ``__flags`` is accepted but unused."""
        if self._write is None:
            raise OSError("stream write not permitted")
        await self._write(__data, self._stream_id, False)

    async def write(self, __data: bytes) -> int:
        """io-style write: send everything, report the byte count."""
        if self._write is None:
            raise OSError("stream write not permitted")
        await self._write(__data, self._stream_id, False)
        return len(__data)

    async def sendall_extended(
        self, __data: bytes, __close_stream: bool = False
    ) -> None:
        """Send *__data*, optionally signalling end-of-stream afterwards."""
        if self._write is None:
            raise OSError("stream write not permitted")
        await self._write(__data, self._stream_id, __close_stream)

    async def close(self) -> None:
        """Terminate both directions: emit an end-of-stream write, then
        abort the read side (final ``True`` flag — presumably "abort")."""
        if self._write is not None:
            await self._write(b"", self._stream_id, True)
            self._write = None
        if self._read is not None:
            await self._read(None, self._stream_id, False, True)
            self._read = None
class AsyncLowLevelResponse:
    """Implemented for backward compatibility purposes. It is there to impose http.client like
    basic response object. So that we don't have to change urllib3 tested behaviors."""

    #: Coroutine that pulls the next body chunk from the protocol layer.
    #: Becomes ``None`` once the response is closed or never had a body.
    __internal_read_st: (
        typing.Callable[
            [int | None, int | None],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None
    )

    def __init__(
        self,
        method: str,
        status: int,
        version: int,
        reason: str,
        headers: HTTPHeaderDict,
        body: typing.Callable[
            [int | None, int | None],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None,
        *,
        authority: str | None = None,
        port: int | None = None,
        stream_id: int | None = None,
        # this obj should not be always available[...]
        dsa: AsyncDirectStreamAccess | None = None,
        stream_abort: typing.Callable[[int], typing.Awaitable[None]] | None = None,
    ) -> None:
        self.status = status
        self.version = version
        self.reason = reason
        self.msg = headers
        self._method = method

        self.__internal_read_st = body
        has_body = self.__internal_read_st is not None
        # without a body reader, the response is born closed / at end-of-transmission
        self.closed = has_body is False
        self._eot = self.closed

        # is kept to determine if we can upgrade conn
        self.authority = authority
        self.port = port

        # http.client compat layer
        self.debuglevel: int = 0  # no-op flag, kept for strict backward compatibility!
        self.chunked: bool = (  # is "chunked" being used? http1 only!
            self.version == 11 and "chunked" == self.msg.get("transfer-encoding")
        )
        self.chunk_left: int | None = None  # bytes left to read in current chunk
        self.length: int | None = None  # number of bytes left in response
        self.will_close: bool = (
            False  # no-op flag, kept for strict backward compatibility!
        )

        if not self.chunked:
            content_length = self.msg.get("content-length")
            self.length = int(content_length) if content_length else None

        #: not part of http.client but useful to track (raw) download speeds!
        self.data_in_count = 0

        self._stream_id = stream_id

        # retains bytes the protocol delivered beyond what read() asked for
        self.__buffer_excess: BytesQueueBuffer = BytesQueueBuffer()
        self.__promise: ResponsePromise | None = None

        self._dsa = dsa
        self._stream_abort = stream_abort

        #: HTTP trailer fields, populated once the final data frame arrived.
        self.trailers: HTTPHeaderDict | None = None

    @property
    def fp(self) -> typing.NoReturn:
        """http.client-era file-pointer access: deliberately unsupported here."""
        raise RuntimeError(
            "urllib3-future no longer expose a filepointer-like in responses. It was a remnant from the http.client era. "
            "We no longer support it."
        )

    @property
    def from_promise(self) -> ResponsePromise | None:
        """The promise (if any) this response was redeemed from."""
        return self.__promise

    @from_promise.setter
    def from_promise(self, value: ResponsePromise) -> None:
        # a promise is only valid for the stream it was issued on
        if value.stream_id != self._stream_id:
            raise ValueError(
                "Trying to assign a ResponsePromise to an unrelated LowLevelResponse"
            )
        self.__promise = value

    @property
    def method(self) -> str:
        """Original HTTP verb used in the request."""
        return self._method

    def isclosed(self) -> bool:
        """Here we do not create a fp sock like http.client Response."""
        return self.closed

    async def read(self, __size: int | None = None) -> bytes:
        """Read up to ``__size`` bytes of body (everything when ``None``)."""
        if self.closed is True or self.__internal_read_st is None:
            # overly protective, just in case.
            raise ValueError(
                "I/O operation on closed file."
            )  # Defensive: Should not be reachable in normal condition

        if __size == 0:
            return b""  # Defensive: This is unreachable, this case is already covered higher in the stack.

        buf_capacity = len(self.__buffer_excess)
        # skip the protocol layer entirely when the retained buffer alone
        # can honor the caller's request.
        data_ready_to_go = (
            __size is not None and buf_capacity > 0 and buf_capacity >= __size
        )

        if self._eot is False and not data_ready_to_go:
            data, self._eot, self.trailers = await self.__internal_read_st(
                __size, self._stream_id
            )
            # the protocol may hand back more than requested: park the surplus.
            self.__buffer_excess.put(data)
            buf_capacity = len(self.__buffer_excess)

        data = self.__buffer_excess.get(
            __size if __size is not None and __size > 0 else buf_capacity
        )

        size_in = len(data)
        buf_capacity -= size_in

        if self._eot and buf_capacity == 0:
            # fully drained: retire the response.
            # Fix: the sync flavor's trailing ``self._sock = None`` was removed;
            # this async class never defines, receives, nor reads a ``_sock``.
            self._stream_abort = None
            self.closed = True

        if self.chunked:
            self.chunk_left = buf_capacity if buf_capacity else None
        elif self.length is not None:
            self.length -= size_in

        self.data_in_count += size_in

        return data

    async def abort(self) -> None:
        """Cancel the stream if it did not complete, then mark us closed."""
        if self._stream_abort is not None:
            if self._eot is False:
                if self._stream_id is not None:
                    await self._stream_abort(self._stream_id)
                self._eot = True
        self._stream_abort = None
        self.closed = True
        self._dsa = None

    def close(self) -> None:
        """Detach the body reader; purely local, no network side effect."""
        self.__internal_read_st = None
        self.closed = True
        self._dsa = None
class AsyncBaseBackend(BaseBackend):
    """Async counterpart of :class:`BaseBackend`: same contract, awaitable methods."""

    # the async flavor carries async-capable sockets instead of raw ones
    sock: AsyncSocket | SSLAsyncSocket | None  # type: ignore[assignment]

    async def _upgrade(self) -> None:  # type: ignore[override]
        """Upgrade conn from svn ver to max supported."""
        raise NotImplementedError

    async def _tunnel(self) -> None:  # type: ignore[override]
        """Emit proper CONNECT request to the http (server) intermediary."""
        raise NotImplementedError

    async def _new_conn(self) -> AsyncSocket | None:  # type: ignore[override]
        """Run protocol initialization from there. Return None to ensure that the child
        class correctly create the socket / connection."""
        raise NotImplementedError

    async def _post_conn(self) -> None:  # type: ignore[override]
        """Should be called after _new_conn proceed as expected.
        Expect protocol handshake to be done here."""
        raise NotImplementedError

    async def endheaders(  # type: ignore[override]
        self,
        message_body: bytes | None = None,
        *,
        encode_chunked: bool = False,
        expect_body_afterward: bool = False,
    ) -> ResponsePromise | None:
        """This method conclude the request context construction."""
        raise NotImplementedError

    async def getresponse(  # type: ignore[override]
        self, *, promise: ResponsePromise | None = None
    ) -> AsyncLowLevelResponse:
        """Fetch the HTTP response. You SHOULD not retrieve the body in that method, it SHOULD be done
        in the LowLevelResponse, so it enable stream capabilities and remain efficient.
        """
        raise NotImplementedError

    async def close(self) -> None:  # type: ignore[override]
        """End the connection, do some reinit, closing of fd, etc..."""
        raise NotImplementedError

    async def send(  # type: ignore[override]
        self,
        data: (bytes | typing.IO[typing.Any] | typing.Iterable[bytes] | str),
        *,
        eot: bool = False,
    ) -> ResponsePromise | None:
        """The send() method SHOULD be invoked after calling endheaders() if and only if the request
        context specify explicitly that a body is going to be sent."""
        raise NotImplementedError

    async def ping(self) -> None:  # type: ignore[override]
        """Send a PING to the remote peer."""
        raise NotImplementedError

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,690 @@
from __future__ import annotations
import enum
import socket
import time
import typing
import warnings
from base64 import b64encode
from datetime import datetime, timedelta
from secrets import token_bytes
if typing.TYPE_CHECKING:
from ssl import SSLSocket, SSLContext, TLSVersion
from .._typing import _TYPE_SOCKET_OPTIONS
from ._async import AsyncLowLevelResponse
from .._collections import HTTPHeaderDict
from .._constant import DEFAULT_BLOCKSIZE, DEFAULT_KEEPALIVE_DELAY
from ..util.response import BytesQueueBuffer
class HttpVersion(str, enum.Enum):
    """Describe possible SVN protocols that can be supported."""

    # str mixin: members compare equal to their literal value.
    h11 = "HTTP/1.1"
    # we know that it is rather "HTTP/2" than "HTTP/2.0"
    # it is this way to remain somewhat compatible with http.client
    # http_svn (int). 9 -> 11 -> 20 -> 30
    h2 = "HTTP/2.0"
    h3 = "HTTP/3.0"
class ConnectionInfo:
    """Timings and TLS details collected for one established connection.

    Every attribute starts as ``None`` and is filled in as the matching
    phase (resolution, TCP, TLS, request emission) completes. Unless noted
    otherwise, values describe the remote peer, never the proxy.
    """

    def __init__(self) -> None:
        #: Time taken to establish the connection
        self.established_latency: timedelta | None = None
        #: HTTP protocol used with the remote peer (not the proxy)
        self.http_version: HttpVersion | None = None
        #: The SSL certificate presented by the remote peer (not the proxy)
        self.certificate_der: bytes | None = None
        self.certificate_dict: (
            dict[str, int | tuple[tuple[str, str], ...] | tuple[str, ...] | str] | None
        ) = None
        #: The SSL issuer certificate for the remote peer certificate (not the proxy)
        self.issuer_certificate_der: bytes | None = None
        self.issuer_certificate_dict: (
            dict[str, int | tuple[tuple[str, str], ...] | tuple[str, ...] | str] | None
        ) = None
        #: The IP address used to reach the remote peer (not the proxy), that was yield by your resolver.
        self.destination_address: tuple[str, int] | None = None
        #: The TLS cipher used to secure the exchanges (not the proxy)
        self.cipher: str | None = None
        #: The TLS revision used (not the proxy)
        self.tls_version: TLSVersion | None = None
        #: The time taken to reach a complete TLS liaison between the remote peer and us. (not the proxy)
        self.tls_handshake_latency: timedelta | None = None
        #: Time taken to resolve a domain name into a reachable IP address.
        self.resolution_latency: timedelta | None = None
        #: Time taken to encode and send the whole request through the socket.
        self.request_sent_latency: timedelta | None = None

    def __repr__(self) -> str:
        # keep the historical key ordering of the hand-written dict literal
        exposed = (
            "established_latency",
            "certificate_der",
            "certificate_dict",
            "issuer_certificate_der",
            "issuer_certificate_dict",
            "destination_address",
            "cipher",
            "tls_version",
            "tls_handshake_latency",
            "http_version",
            "resolution_latency",
            "request_sent_latency",
        )
        return str({attr: getattr(self, attr) for attr in exposed})

    def is_encrypted(self) -> bool:
        # a peer certificate is only ever captured over TLS
        return self.certificate_der is not None
class DirectStreamAccess:
    """Raw, socket-like access to a single multiplexed stream.

    The ``read``/``write`` callables from the protocol layer are bound to
    ``stream_id`` at construction time. Each direction is retired (set to
    ``None``) once it reaches its logical end.
    """

    def __init__(
        self,
        stream_id: int,
        read: typing.Callable[
            [int | None, int | None, bool, bool],
            tuple[bytes, bool, HTTPHeaderDict | None],
        ]
        | None = None,
        write: typing.Callable[[bytes, int, bool], None] | None = None,
    ) -> None:
        self._stream_id = stream_id

        if read is None:
            self._read: (
                typing.Callable[
                    [int | None, bool], tuple[bytes, bool, HTTPHeaderDict | None]
                ]
                | None
            ) = None
        else:

            def _bound_read(
                amt: int | None, flush: bool
            ) -> tuple[bytes, bool, HTTPHeaderDict | None]:
                # the "stream mode" flag mirrors whether a byte-count was given
                return read(amt, self._stream_id, amt is not None, flush)

            self._read = _bound_read

        if write is None:
            self._write: typing.Callable[[bytes, bool], None] | None = None
        else:

            def _bound_write(payload: bytes, eot: bool) -> None:
                write(payload, self._stream_id, eot)

            self._write = _bound_write

    @property
    def closed(self) -> bool:
        """The stream counts as closed once both directions are retired."""
        return self._read is None and self._write is None

    def readinto(self, b: bytearray) -> int:
        """Fill *b* with the next chunk; return the number of bytes received."""
        if self._read is None:
            raise OSError("read operation on a closed stream")
        chunk = self.recv(len(b))
        if not chunk:
            return 0
        b[: len(chunk)] = chunk
        return len(chunk)

    def readable(self) -> bool:
        return self._read is not None

    def writable(self) -> bool:
        return self._write is not None

    def seekable(self) -> bool:
        # streams are strictly sequential
        return False

    def fileno(self) -> int:
        # no real file descriptor backs this object
        return -1

    def name(self) -> int:
        return -1

    def recv(self, __bufsize: int, __flags: int = 0) -> bytes:
        """``socket.recv`` lookalike: return up to ``__bufsize`` bytes."""
        payload, _, _ = self.recv_extended(__bufsize)
        return payload

    def recv_extended(
        self, __bufsize: int | None
    ) -> tuple[bytes, bool, HTTPHeaderDict | None]:
        """Read and also report end-of-transmission plus eventual trailers."""
        if self._read is None:
            raise OSError("stream closed error")
        payload, end_of_transmission, trailers = self._read(__bufsize, False)
        if end_of_transmission:
            # nothing more will ever arrive: retire the read side.
            self._read = None
        return payload, end_of_transmission, trailers

    def sendall(self, __data: bytes, __flags: int = 0) -> None:
        """``socket.sendall`` lookalike: push ``__data`` onto the stream."""
        if self._write is None:
            raise OSError("stream write not permitted")
        self._write(__data, False)

    def write(self, __data: bytes) -> int:
        """File-like write: send ``__data`` and return its length."""
        if self._write is None:
            raise OSError("stream write not permitted")
        self._write(__data, False)
        return len(__data)

    def sendall_extended(self, __data: bytes, __close_stream: bool = False) -> None:
        """Like :meth:`sendall` but may simultaneously close the write side."""
        if self._write is None:
            raise OSError("stream write not permitted")
        self._write(__data, __close_stream)

    def close(self) -> None:
        """Terminate both directions, signalling end-of-stream to the peer."""
        if self._write is not None:
            self._write(b"", True)
            self._write = None
        if self._read is not None:
            self._read(None, True)
            self._read = None
class LowLevelResponse:
    """Implemented for backward compatibility purposes. It is there to impose http.client like
    basic response object. So that we don't have to change urllib3 tested behaviors."""

    def __init__(
        self,
        method: str,
        status: int,
        version: int,
        reason: str,
        headers: HTTPHeaderDict,
        body: typing.Callable[
            [int | None, int | None], tuple[bytes, bool, HTTPHeaderDict | None]
        ]
        | None,
        *,
        authority: str | None = None,
        port: int | None = None,
        stream_id: int | None = None,
        sock: socket.socket | None = None,
        # this obj should not be always available[...]
        dsa: DirectStreamAccess | None = None,
        stream_abort: typing.Callable[[int], None] | None = None,
    ):
        self.status = status
        self.version = version
        self.reason = reason
        self.msg = headers
        self._method = method
        # callable pulling the next body chunk from the protocol layer;
        # None when the response has no body (or once closed).
        self.__internal_read_st = body
        has_body = self.__internal_read_st is not None
        # without a body reader, the response is born closed / at end-of-transmission
        self.closed = has_body is False
        self._eot = self.closed
        # is kept to determine if we can upgrade conn
        self.authority = authority
        self.port = port
        # http.client additional compat layer
        # although rarely used, some 3rd party library may
        # peek at those for whatever reason. most of the time they
        # are wrong to do so.
        self.debuglevel: int = 0  # no-op flag, kept for strict backward compatibility!
        self.chunked: bool = (  # is "chunked" being used? http1 only!
            self.version == 11 and "chunked" == self.msg.get("transfer-encoding")
        )
        self.chunk_left: int | None = None  # bytes left to read in current chunk
        self.length: int | None = None  # number of bytes left in response
        self.will_close: bool = (
            False  # no-op flag, kept for strict backward compatibility!
        )
        if not self.chunked:
            content_length = self.msg.get("content-length")
            self.length = int(content_length) if content_length else None
        #: not part of http.client but useful to track (raw) download speeds!
        self.data_in_count = 0
        # tricky part...
        # sometime 3rd party library tend to access hazardous materials...
        # they want a direct socket access.
        self._sock = sock
        self._fp: socket.SocketIO | None = None  # lazily created by the fp property
        self._dsa = dsa
        self._stream_abort = stream_abort
        self._stream_id = stream_id
        # retains bytes the protocol delivered beyond what read() asked for
        self.__buffer_excess: BytesQueueBuffer = BytesQueueBuffer()
        self.__promise: ResponsePromise | None = None
        # HTTP trailer fields, populated once the final data frame arrived.
        self.trailers: HTTPHeaderDict | None = None

    @property
    def fp(self) -> socket.SocketIO | DirectStreamAccess | None:
        # Hazmat escape hatch: only meaningful for HTTP/1.1 upgrades/tunnels.
        warnings.warn(
            (
                "This is a rather awkward situation. A program (probably) tried to access the socket object "
                "directly, thus bypassing our state-machine protocol (amongst other things). "
                "This is currently unsupported and dangerous. Errors will occurs if you negotiated HTTP/2 or later versions. "
                "We tried to be rather strict on the backward compatibility between urllib3 and urllib3-future, "
                "but this is rather complicated to support (e.g. direct socket access). "
                "You are probably better off using our higher level read() function. "
                "Please open an issue at https://github.com/jawah/urllib3.future/issues to gain support or "
                "insights on it."
            ),
            DeprecationWarning,
            2,
        )
        if self._sock is None:
            # no raw socket kept: hand out the stream accessor, but only for
            # 101 Switching Protocols or a successful CONNECT.
            if self.status == 101 or (
                self._method == "CONNECT" and 200 <= self.status < 300
            ):
                return self._dsa
            # well, there's nothing we can do more :'(
            raise AttributeError
        if self._fp is None:
            self._fp = self._sock.makefile("rb")  # type: ignore[assignment]
        return self._fp

    @property
    def from_promise(self) -> ResponsePromise | None:
        # the promise (if any) this response was redeemed from
        return self.__promise

    @from_promise.setter
    def from_promise(self, value: ResponsePromise) -> None:
        # a promise is only valid for the stream it was issued on
        if value.stream_id != self._stream_id:
            raise ValueError(
                "Trying to assign a ResponsePromise to an unrelated LowLevelResponse"
            )
        self.__promise = value

    @property
    def method(self) -> str:
        """Original HTTP verb used in the request."""
        return self._method

    def isclosed(self) -> bool:
        """Here we do not create a fp sock like http.client Response."""
        return self.closed

    def read(self, __size: int | None = None) -> bytes:
        """Read up to ``__size`` bytes of body (everything when ``None``)."""
        if self.closed is True or self.__internal_read_st is None:
            # overly protective, just in case.
            raise ValueError(
                "I/O operation on closed file."
            )  # Defensive: Should not be reachable in normal condition
        if __size == 0:
            return b""  # Defensive: This is unreachable, this case is already covered higher in the stack.
        buf_capacity = len(self.__buffer_excess)
        # skip the protocol layer entirely when the retained buffer alone
        # can honor the caller's request.
        data_ready_to_go = (
            __size is not None and buf_capacity > 0 and buf_capacity >= __size
        )
        if self._eot is False and not data_ready_to_go:
            data, self._eot, self.trailers = self.__internal_read_st(
                __size, self._stream_id
            )
            # the protocol may hand back more than requested: park the surplus.
            self.__buffer_excess.put(data)
            buf_capacity = len(self.__buffer_excess)
        data = self.__buffer_excess.get(
            __size if __size is not None and __size > 0 else buf_capacity
        )
        size_in = len(data)
        buf_capacity -= size_in
        if self._eot and buf_capacity == 0:
            # fully drained: retire the response and drop the raw socket ref.
            self._stream_abort = None
            self.closed = True
            self._sock = None
        if self.chunked:
            self.chunk_left = buf_capacity if buf_capacity else None
        elif self.length is not None:
            self.length -= size_in
        self.data_in_count += size_in
        return data

    def abort(self) -> None:
        # Cancel the stream if it did not complete, then mark us closed.
        if self._stream_abort is not None:
            if self._eot is False:
                if self._stream_id is not None:
                    self._stream_abort(self._stream_id)
                self._eot = True
        self._stream_abort = None
        self.closed = True
        self._dsa = None

    def close(self) -> None:
        # Detach the body reader; purely local, no network side effect.
        self.__internal_read_st = None
        self.closed = True
        self._sock = None
        self._dsa = None
class ResponsePromise:
    """Ticket handed out when a request is multiplexed; it is redeemed
    later, against the same connection, for the matching response."""

    def __init__(
        self,
        conn: BaseBackend,
        stream_id: int,
        request_headers: list[tuple[bytes, bytes]],
        **parameters: typing.Any,
    ) -> None:
        # 128 bits of randomness, base64-encoded: uniquely tags this promise.
        self._uid: str = b64encode(token_bytes(16)).decode("ascii")
        self._conn: BaseBackend = conn
        self._stream_id: int = stream_id
        self._response: LowLevelResponse | AsyncLowLevelResponse | None = None
        self._request_headers = request_headers
        self._parameters: typing.MutableMapping[str, typing.Any] = parameters

    def __eq__(self, other: object) -> bool:
        # identity is carried solely by the random uid
        return isinstance(other, ResponsePromise) and self.uid == other.uid

    def __repr__(self) -> str:
        return f"<ResponsePromise '{self.uid}' {self._conn._http_vsn_str} Stream[{self.stream_id}]>"

    @property
    def uid(self) -> str:
        """Opaque unique identifier of this promise."""
        return self._uid

    @property
    def request_headers(self) -> list[tuple[bytes, bytes]]:
        """Headers that were emitted with the originating request."""
        return self._request_headers

    @property
    def stream_id(self) -> int:
        """Protocol stream this promise was issued on."""
        return self._stream_id

    @property
    def is_ready(self) -> bool:
        """Whether the response has already been attached."""
        return self._response is not None

    @property
    def response(self) -> LowLevelResponse | AsyncLowLevelResponse:
        """The fulfilled response; raises OSError while still pending."""
        if not self._response:
            raise OSError
        return self._response

    @response.setter
    def response(self, value: LowLevelResponse | AsyncLowLevelResponse) -> None:
        self._response = value

    def set_parameter(self, key: str, value: typing.Any) -> None:
        """Attach an arbitrary key/value to carry context across the exchange."""
        self._parameters[key] = value

    def get_parameter(self, key: str) -> typing.Any | None:
        """Retrieve a previously attached value, ``None`` when absent."""
        return self._parameters.get(key)

    def update_parameters(self, data: dict[str, typing.Any]) -> None:
        """Bulk-merge *data* into the attached parameters."""
        self._parameters.update(data)
# A (host, port) pair identifying an endpoint.
_HostPortType: typing.TypeAlias = typing.Tuple[str, int]
# Maps an origin (host, port) to an alternative endpoint or None —
# presumably so HTTP/3 (QUIC) can be attempted preemptively on later
# connections; confirm against the consumers of preemptive_quic_cache.
QuicPreemptiveCacheType: typing.TypeAlias = typing.MutableMapping[
    _HostPortType, typing.Optional[_HostPortType]
]
class BaseBackend:
    """
    The goal here is to detach ourselves from the http.client package.
    At first, we'll strictly follow the methods in http.client.HTTPConnection. So that
    we would be able to implement other backends without disrupting the actual code base.
    Extend that base class in order to ship another backend with urllib3.
    """

    #: Protocol versions a concrete backend can speak; None means unspecified.
    supported_svn: typing.ClassVar[list[HttpVersion] | None] = None
    #: URL scheme (e.g. "http"/"https") served by the concrete backend.
    scheme: typing.ClassVar[str]
    # TCP by default; a datagram-based backend would override this.
    default_socket_kind: socket.SocketKind = socket.SOCK_STREAM
    #: Disable Nagle's algorithm by default.
    default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS] = [
        (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp")
    ]
    #: Whether this connection verifies the host's certificate.
    is_verified: bool = False
    #: Whether this proxy connection verified the proxy host's certificate.
    # If no proxy is currently connected to the value will be ``None``.
    proxy_is_verified: bool | None = None
    #: Response class instantiated by getresponse().
    response_class = LowLevelResponse

    def __init__(
        self,
        host: str,
        port: int | None = None,
        timeout: int | float | None = -1,
        source_address: tuple[str, int] | None = None,
        blocksize: int = DEFAULT_BLOCKSIZE,
        *,
        socket_options: _TYPE_SOCKET_OPTIONS | None = default_socket_options,
        disabled_svn: set[HttpVersion] | None = None,
        preemptive_quic_cache: QuicPreemptiveCacheType | None = None,
        keepalive_delay: float | int | None = DEFAULT_KEEPALIVE_DELAY,
    ):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.source_address = source_address
        self.blocksize = blocksize
        self.socket_kind = BaseBackend.default_socket_kind
        self.socket_options = socket_options
        self.sock: socket.socket | SSLSocket | None = None
        self._response: LowLevelResponse | AsyncLowLevelResponse | None = None
        # Set it as default
        self._svn: HttpVersion | None = HttpVersion.h11
        # CONNECT tunnel target and credentials, populated by set_tunnel().
        self._tunnel_host: str | None = None
        self._tunnel_port: int | None = None
        self._tunnel_scheme: str | None = None
        self._tunnel_headers: typing.Mapping[str, str] = dict()
        self._disabled_svn = disabled_svn if disabled_svn is not None else set()
        self._preemptive_quic_cache = preemptive_quic_cache
        if self._disabled_svn:
            # disabling every known version would leave no usable protocol
            if len(self._disabled_svn) == len(list(HttpVersion)):
                raise RuntimeError(
                    "You disabled every supported protocols. The HTTP connection object is left with no outcomes."
                )
        # valuable intel
        self.conn_info: ConnectionInfo | None = None
        # in-flight requests, indexed by promise uid and by stream id
        self._promises: dict[str, ResponsePromise] = {}
        self._promises_per_stream: dict[int, ResponsePromise] = {}
        # responses fetched but not fully consumed yet, per stream id
        self._pending_responses: dict[
            int, LowLevelResponse | AsyncLowLevelResponse
        ] = {}
        self._start_last_request: datetime | None = None
        # memoized integer form of _svn, see the _http_vsn property
        self._cached_http_vsn: int | None = None
        self._keepalive_delay: float | None = (
            keepalive_delay  # just forwarded for qh3 idle_timeout conf.
        )
        self._connected_at: float | None = None
        self._last_used_at: float = time.monotonic()
        self._recv_size_ema: float = 0.0

    def __contains__(self, item: ResponsePromise) -> bool:
        # membership test == "is this promise still awaiting a response here?"
        return item.uid in self._promises

    @property
    def _fast_recv_mode(self) -> bool:
        # a single in-flight request (or HTTP/3) always reads greedily
        if len(self._promises) <= 1 or self._svn is HttpVersion.h3:
            return True
        # NOTE(review): 1450 looks like a typical TCP MSS payload and
        # _recv_size_ema a moving average of received sizes — confirm intent.
        return self._recv_size_ema >= 1450

    @property
    def last_used_at(self) -> float:
        """Monotonic timestamp of the most recent recorded activity."""
        return self._last_used_at

    @property
    def connected_at(self) -> float | None:
        """Monotonic timestamp captured at connection establishment, if any."""
        return self._connected_at

    @property
    def disabled_svn(self) -> set[HttpVersion]:
        """Protocol versions the caller explicitly opted out of."""
        return self._disabled_svn

    @property
    def _http_vsn_str(self) -> str:
        """Reimplemented for backward compatibility purposes."""
        assert self._svn is not None
        return self._svn.value

    @property
    def _http_vsn(self) -> int:
        """Reimplemented for backward compatibility purposes."""
        assert self._svn is not None
        if self._cached_http_vsn is None:
            # "HTTP/1.1" -> 11, "HTTP/2.0" -> 20, "HTTP/3.0" -> 30
            self._cached_http_vsn = int(self._svn.value.split("/")[-1].replace(".", ""))
        return self._cached_http_vsn

    @property
    def is_saturated(self) -> bool:
        """Abstract: whether no additional concurrent request can be taken."""
        raise NotImplementedError

    @property
    def is_idle(self) -> bool:
        """True when nothing is in-flight and no response awaits consumption."""
        return not self._promises and not self._pending_responses

    @property
    def max_stream_count(self) -> int:
        """Abstract: maximum number of concurrent streams permitted."""
        raise NotImplementedError

    @property
    def is_multiplexed(self) -> bool:
        """Abstract: whether several requests may be in-flight at once."""
        raise NotImplementedError

    @property
    def max_frame_size(self) -> int:
        """Abstract: largest payload writable in one protocol frame."""
        raise NotImplementedError

    def _upgrade(self) -> None:
        """Upgrade conn from svn ver to max supported."""
        raise NotImplementedError

    def _tunnel(self) -> None:
        """Emit proper CONNECT request to the http (server) intermediary."""
        raise NotImplementedError

    def _new_conn(self) -> socket.socket | None:
        """Run protocol initialization from there. Return None to ensure that the child
        class correctly create the socket / connection."""
        raise NotImplementedError

    def _post_conn(self) -> None:
        """Should be called after _new_conn proceed as expected.
        Expect protocol handshake to be done here."""
        raise NotImplementedError

    def _custom_tls(
        self,
        ssl_context: SSLContext | None = None,
        ca_certs: str | None = None,
        ca_cert_dir: str | None = None,
        ca_cert_data: None | str | bytes = None,
        ssl_minimum_version: int | None = None,
        ssl_maximum_version: int | None = None,
        cert_file: str | None = None,
        key_file: str | None = None,
        key_password: str | None = None,
    ) -> bool:
        """This method serve as bypassing any default tls setup.
        It is most useful when the encryption does not lie on the TCP layer. This method
        WILL raise NotImplementedError if the connection is not concerned."""
        raise NotImplementedError

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Prepare the connection to set up a tunnel. Does NOT actually do the socket and http connect.
        Here host:port represent the target (final) server and not the intermediary."""
        raise NotImplementedError

    def putrequest(
        self,
        method: str,
        url: str,
        skip_host: bool = False,
        skip_accept_encoding: bool = False,
    ) -> None:
        """It is the first method called, setting up the request initial context."""
        raise NotImplementedError

    def putheader(self, header: str, *values: str) -> None:
        """For a single header name, assign one or multiple value. This method is called right after putrequest()
        for each entries."""
        raise NotImplementedError

    def endheaders(
        self,
        message_body: bytes | None = None,
        *,
        encode_chunked: bool = False,
        expect_body_afterward: bool = False,
    ) -> ResponsePromise | None:
        """This method conclude the request context construction."""
        raise NotImplementedError

    def getresponse(
        self, *, promise: ResponsePromise | None = None
    ) -> LowLevelResponse:
        """Fetch the HTTP response. You SHOULD not retrieve the body in that method, it SHOULD be done
        in the LowLevelResponse, so it enable stream capabilities and remain efficient.
        """
        raise NotImplementedError

    def close(self) -> None:
        """End the connection, do some reinit, closing of fd, etc..."""
        raise NotImplementedError

    def send(
        self,
        data: bytes | bytearray,
        *,
        eot: bool = False,
    ) -> ResponsePromise | None:
        """The send() method SHOULD be invoked after calling endheaders() if and only if the request
        context specify explicitly that a body is going to be sent."""
        raise NotImplementedError

    def ping(self) -> None:
        """Send a PING to the remote peer."""
        raise NotImplementedError

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,237 @@
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...'``)
- SOCKS4 (``proxy_url='socks4://...'``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...'``)
- SOCKS5 with local DNS (``proxy_url='socks5://...'``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import annotations
try:
import socks
except ImportError:
import warnings
from ..exceptions import DependencyWarning
warnings.warn(
(
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning,
)
raise
import typing
from socket import timeout as SocketTimeout
from .._typing import _TYPE_SOCKS_OPTIONS
from ..backend import HttpVersion
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
try:
import ssl
except ImportError:
ssl = None # type: ignore[assignment]
class SOCKSConnection(HTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # proxy coordinates/credentials, forwarded to socks.create_connection
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    def _new_conn(self) -> socks.socksocket:
        """
        Establish a new connection via the SOCKS proxy.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # PySocks carries TCP only: drop options tagged "udp" and strip
            # the 4th (protocol-tag) element before handing options over.
            only_tcp_options = []
            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])
            extra_kw["socket_options"] = only_tcp_options

        try:
            conn = socks.create_connection(
                (self.host, self.port),
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                proxy_addr=self._socks_options["proxy_host"],
                proxy_port=self._socks_options["proxy_port"],  # type: ignore[arg-type]
                proxy_username=self._socks_options["username"],
                proxy_password=self._socks_options["password"],
                proxy_rdns=self._socks_options["rdns"],
                timeout=self.timeout,  # type: ignore[arg-type]
                **extra_kw,
            )
        except SocketTimeout as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e
        except socks.ProxyError as e:
            # This is fragile as hell, but it seems to be the only way to raise
            # useful errors here.
            if e.socket_err:
                error = e.socket_err
                if isinstance(error, SocketTimeout):
                    raise ConnectTimeoutError(
                        self,
                        f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
                    ) from e
                else:
                    # Adding `from e` messes with coverage somehow, so it's omitted.
                    # See #2386.
                    raise NewConnectionError(
                        self, f"Failed to establish a new connection: {error}"
                    )
            else:
                raise NewConnectionError(
                    self, f"Failed to establish a new connection: {e}"
                ) from e
        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        return conn
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
    """TLS flavor of SOCKSConnection; the MRO combines both behaviors."""

    pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):
    """Plain-text pool whose connections go through a SOCKS proxy."""

    ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
    """TLS pool whose connections go through a SOCKS proxy."""

    ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # Credentials may be embedded in the URL when not given explicitly.
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (PySocks protocol constant, remote-DNS flag)
        scheme_to_settings = {
            "socks5": (socks.PROXY_TYPE_SOCKS5, False),
            "socks5h": (socks.PROXY_TYPE_SOCKS5, True),
            "socks4": (socks.PROXY_TYPE_SOCKS4, False),
            "socks4a": (socks.PROXY_TYPE_SOCKS4, True),
        }

        if parsed.scheme not in scheme_to_settings:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = scheme_to_settings[parsed.scheme]

        self.proxy_url = proxy_url

        connection_pool_kw["_socks_options"] = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }

        # SOCKS tunnels carry TCP only: HTTP/3 (QUIC over UDP) cannot
        # traverse them, so make sure it stays disabled for every pool.
        connection_pool_kw.setdefault("disabled_svn", set())
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme

View File

@@ -0,0 +1,173 @@
"""
This is hazmat. It can blow up anytime.
Use it with precautions!
Reasoning behind this:
1) python-socks requires another dependency, namely asyncio-timeout, that is one too much for us.
2) it does not support our AsyncSocket wrapper (it has his own internally)
"""
from __future__ import annotations
import asyncio
import socket
import typing
import warnings
from python_socks import _abc as abc
# look the other way if unpleasant. No choice for now.
# will start discussions once we have a solid traffic.
from python_socks._connectors.abc import AsyncConnector
from python_socks._connectors.socks4_async import Socks4AsyncConnector
from python_socks._connectors.socks5_async import Socks5AsyncConnector
from python_socks._errors import ProxyError, ProxyTimeoutError
from python_socks._helpers import parse_proxy_url
from python_socks._protocols.errors import ReplyError
from python_socks._types import ProxyType
from .ssa import AsyncSocket
from .ssa._timeout import timeout as timeout_
class Resolver(abc.AsyncResolver):
    """DNS resolver backed by the asyncio event loop's ``getaddrinfo``."""
    def __init__(self, loop: asyncio.AbstractEventLoop):
        self._loop = loop
    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_UNSPEC
    ) -> tuple[socket.AddressFamily, str]:
        """Resolve *host* and return the preferred ``(family, address)`` pair."""
        records = await self._loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=socket.SOCK_STREAM,
        )
        if not records:  # Defensive:
            raise OSError(f"Can`t resolve address {host}:{port} [{family}]")
        # Pick the record with the lowest address family value (e.g. AF_INET
        # before AF_INET6); ties keep the resolver's original ordering.
        best_family, _, _, _, sockaddr = min(records, key=lambda record: record[0])
        return best_family, sockaddr[0]
def create_connector(
    proxy_type: ProxyType,
    username: str | None,
    password: str | None,
    rdns: bool,
    resolver: abc.AsyncResolver,
) -> AsyncConnector:
    """Build the python-socks connector matching *proxy_type*.

    SOCKS4 only carries a user id, while SOCKS5 supports username/password
    authentication. Any other proxy type raises ``ValueError``.
    """
    if proxy_type == ProxyType.SOCKS4:
        connector: AsyncConnector = Socks4AsyncConnector(
            user_id=username,
            rdns=rdns,
            resolver=resolver,
        )
    elif proxy_type == ProxyType.SOCKS5:
        connector = Socks5AsyncConnector(
            username=username,
            password=password,
            rdns=rdns,
            resolver=resolver,
        )
    else:
        raise ValueError(f"Invalid proxy type: {proxy_type}")
    return connector
class AsyncioProxy:
    """Minimal asyncio replacement for python-socks' ``Proxy`` object.

    Reuses python-socks' connectors but drives them over our own
    ``AsyncSocket``, for the reasons laid out in the module docstring.
    """
    def __init__(
        self,
        proxy_type: ProxyType,
        host: str,
        port: int,
        username: str | None = None,
        password: str | None = None,
        rdns: bool = False,
    ):
        # NOTE(review): get_event_loop() without a running loop is deprecated
        # on recent Pythons — confirm callers always invoke this from a loop.
        self._loop = asyncio.get_event_loop()
        self._proxy_type = proxy_type
        self._proxy_host = host
        self._proxy_port = port
        self._password = password
        self._username = username
        self._rdns = rdns
        self._resolver = Resolver(loop=self._loop)
    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        timeout: float | None = None,
        _socket: AsyncSocket | None = None,
    ) -> AsyncSocket:
        """Negotiate the SOCKS tunnel toward ``dest_host:dest_port``.

        Raises ``ProxyTimeoutError`` when the whole negotiation exceeds
        *timeout* (defaults to 60 seconds).
        """
        if timeout is None:
            timeout = 60
        try:
            async with timeout_(timeout):
                # our dependency started to deprecate passing "_socket"
                # which is ... vital for our integration. We'll start by silencing the warning.
                # then we'll think on how to proceed.
                # A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
                # B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", DeprecationWarning)
                    return await self._connect(
                        dest_host=dest_host,
                        dest_port=dest_port,
                        _socket=_socket,  # type: ignore[arg-type]
                    )
        except asyncio.TimeoutError as e:
            raise ProxyTimeoutError(f"Proxy connection timed out: {timeout}") from e
    async def _connect(
        self, dest_host: str, dest_port: int, _socket: AsyncSocket
    ) -> AsyncSocket:
        # Runs the SOCKS handshake on the already-connected socket; on any
        # failure the socket is closed so callers never see a half-open one.
        try:
            connector = create_connector(
                proxy_type=self._proxy_type,
                username=self._username,
                password=self._password,
                rdns=self._rdns,
                resolver=self._resolver,
            )
            await connector.connect(
                stream=_socket,  # type: ignore[arg-type]
                host=dest_host,
                port=dest_port,
            )
            return _socket
        except asyncio.CancelledError:  # Defensive:
            _socket.close()
            raise
        except ReplyError as e:
            # The proxy refused the request; surface it as a ProxyError.
            _socket.close()
            raise ProxyError(e, error_code=e.error_code)  # type: ignore[no-untyped-call]
        except Exception:  # Defensive:
            _socket.close()
            raise
    @property
    def proxy_host(self) -> str:
        return self._proxy_host
    @property
    def proxy_port(self) -> int:
        return self._proxy_port
    @classmethod
    def create(cls, *args: typing.Any, **kwargs: typing.Any) -> AsyncioProxy:
        # Mirrors python-socks' Proxy.create() constructor signature.
        return cls(*args, **kwargs)
    @classmethod
    def from_url(cls, url: str, **kwargs: typing.Any) -> AsyncioProxy:
        # Parses scheme://user:pass@host:port into constructor arguments.
        url_args = parse_proxy_url(url)
        return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,27 @@
# Dummy file to match upstream modules
# without actually serving them.
# urllib3-future diverged from urllib3.
# only the top-level (public API) are guaranteed to be compatible.
# in-fact urllib3-future propose a better way to migrate/transition toward
# newer protocols.
from __future__ import annotations
import warnings
def inject_into_urllib3() -> None:
    """Stub kept for API parity with urllib3's emscripten contrib module.

    urllib3-future has no WASM/Emscripten support, so calling this only
    emits a warning pointing users back to legacy urllib3.
    """
    message = (
        "urllib3-future does not support WASM / Emscripten platform. "
        "Please reinstall legacy urllib3 in the meantime. "
        "Run `pip uninstall -y urllib3 urllib3-future` then "
        "`pip install urllib3-future`, finally `pip install urllib3`. "
        "Sorry for the inconvenience."
    )
    warnings.warn(message, DeprecationWarning)
def extract_from_urllib3() -> None:
    """No-op counterpart of ``inject_into_urllib3``; there is nothing to undo."""
    return None

View File

@@ -0,0 +1,39 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._configuration import QuicTLSConfig
from .protocols import (
HTTP1Protocol,
HTTP2Protocol,
HTTP3Protocol,
HTTPOverQUICProtocol,
HTTPOverTCPProtocol,
HTTPProtocol,
HTTPProtocolFactory,
)
__all__ = (
"QuicTLSConfig",
"HTTP1Protocol",
"HTTP2Protocol",
"HTTP3Protocol",
"HTTPOverQUICProtocol",
"HTTPOverTCPProtocol",
"HTTPProtocol",
"HTTPProtocolFactory",
)

View File

@@ -0,0 +1,59 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
from typing import Any, Mapping
@dataclass
class QuicTLSConfig:
    """
    Client TLS configuration.
    """
    #: Allows to proceed for server without valid TLS certificates.
    insecure: bool = False
    #: File with CA certificates to trust for server verification
    cafile: str | None = None
    #: Directory with CA certificates to trust for server verification
    capath: str | None = None
    #: Blob with CA certificates to trust for server verification
    cadata: bytes | None = None
    #: If provided, will trigger an additional load_cert_chain() upon the QUIC Configuration
    certfile: str | bytes | None = None
    #: Private key matching ``certfile``.
    keyfile: str | bytes | None = None
    #: Password protecting ``keyfile``, if the key is encrypted.
    keypassword: str | bytes | None = None
    #: The QUIC session ticket which should be used for session resumption
    session_ticket: Any | None = None
    #: Expected certificate fingerprint (exact format not shown here — TODO confirm).
    cert_fingerprint: str | None = None
    #: Match against the certificate Common Name — NOTE(review): confirm exact semantics.
    cert_use_common_name: bool = False
    #: Whether to verify the certificate against the target hostname.
    verify_hostname: bool = True
    #: Override the hostname used for certificate matching, if set.
    assert_hostname: str | None = None
    #: Optional cipher restriction for the QUIC configuration — TODO confirm mapping.
    ciphers: list[Mapping[str, Any]] | None = None
    #: Idle timeout (presumably seconds — confirm against the QUIC configuration).
    idle_timeout: float = 300.0

View File

@@ -0,0 +1,151 @@
from __future__ import annotations
import typing
from collections import deque
from .events import Event
class StreamMatrix:
    """Efficient way to store events for concurrent streams.

    Events are bucketed per stream id; connection-wide events live in the
    ``None`` bucket. A monotonically increasing ``_id`` is stamped on every
    stored event so relative age can be compared across buckets.
    """
    __slots__ = (
        "_matrix",
        "_count",
        "_event_cursor_id",
    )
    def __init__(self) -> None:
        # stream id (None = connection-wide) -> FIFO of events
        self._matrix: dict[int | None, deque[Event]] = {}
        # total number of stored events across all buckets
        self._count: int = 0
        # next insertion id to stamp onto an event
        self._event_cursor_id: int = 0
    def __len__(self) -> int:
        return self._count
    def __bool__(self) -> bool:
        return self._count > 0
    @property
    def streams(self) -> list[int]:
        """Sorted ids of streams that currently have queued events."""
        return sorted(i for i in self._matrix.keys() if i is not None)
    def append(self, event: Event) -> None:
        """Queue *event* at the tail of its stream's bucket."""
        matrix_idx = getattr(event, "stream_id", None)
        event._id = self._event_cursor_id
        self._event_cursor_id += 1
        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()
        self._matrix[matrix_idx].append(event)
        self._count += 1
    def extend(self, events: typing.Iterable[Event]) -> None:
        """Queue many events, grouping them per stream before insertion."""
        triaged_events: dict[int | None, list[Event]] = {}
        for event in events:
            matrix_idx = getattr(event, "stream_id", None)
            event._id = self._event_cursor_id
            self._event_cursor_id += 1
            self._count += 1
            if matrix_idx not in triaged_events:
                triaged_events[matrix_idx] = []
            triaged_events[matrix_idx].append(event)
        for k, v in triaged_events.items():
            if k not in self._matrix:
                self._matrix[k] = deque()
            self._matrix[k].extend(v)
    def appendleft(self, event: Event) -> None:
        """Put *event* back at the head of its stream's bucket.

        Note: the event receives a fresh (higher) ``_id`` even though it is
        reinserted at the front — NOTE(review): confirm this interacts as
        intended with the age comparison in :meth:`popleft`.
        """
        matrix_idx = getattr(event, "stream_id", None)
        event._id = self._event_cursor_id
        self._event_cursor_id += 1
        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()
        self._matrix[matrix_idx].appendleft(event)
        self._count += 1
    def popleft(self, stream_id: int | None = None) -> Event | None:
        """Pop the oldest relevant event, or ``None`` when nothing applies.

        A queued connection-wide event (``None`` bucket) is served before the
        selected stream's head event whenever it is older, and whenever the
        requested stream has no events at all.
        """
        if self._count == 0:
            return None
        # Is there at least one connection-wide event queued?
        have_global_event: bool = None in self._matrix and bool(self._matrix[None])
        # Is there at least one per-stream bucket besides the global one?
        any_stream_event: bool = (
            bool(self._matrix) if not have_global_event else len(self._matrix) > 1
        )
        if stream_id is None and any_stream_event:
            # No specific stream requested: take the first stream key in
            # insertion order, skipping the global (None) bucket.
            matrix_dict_iter = self._matrix.__iter__()
            stream_id = next(matrix_dict_iter)
            if stream_id is None:
                stream_id = next(matrix_dict_iter)
        if (
            stream_id is not None
            and have_global_event
            and stream_id in self._matrix
            and self._matrix[None][0]._id < self._matrix[stream_id][0]._id
        ):
            # The global event predates the stream's head event: serve it first.
            stream_id = None
        elif have_global_event is True and stream_id not in self._matrix:
            # Requested stream has nothing queued; fall back to the global bucket.
            stream_id = None
        if stream_id not in self._matrix:
            return None
        ev = self._matrix[stream_id].popleft()
        if ev is not None:  # deque.popleft() raises on empty, so this always holds
            self._count -= 1
            if stream_id is not None and not self._matrix[stream_id]:
                # Drop exhausted per-stream buckets to keep the dict small.
                del self._matrix[stream_id]
        return ev
    def count(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> int:
        """Number of queued events, optionally per stream and filtered.

        NOTE(review): ``excl_event`` is ignored when ``stream_id`` is None —
        confirm that is intentional.
        """
        if stream_id is None:
            return self._count
        if stream_id not in self._matrix:
            return 0
        return len(
            self._matrix[stream_id]
            if excl_event is None
            else [e for e in self._matrix[stream_id] if not isinstance(e, excl_event)]
        )
    def has(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Whether any matching event is queued.

        NOTE(review): as with :meth:`count`, ``excl_event`` is ignored when
        ``stream_id`` is None — confirm intended.
        """
        if stream_id is None:
            return True if self._count else False
        if stream_id not in self._matrix:
            return False
        if excl_event is not None:
            return any(
                e for e in self._matrix[stream_id] if not isinstance(e, excl_event)
            )
        return True if self._matrix[stream_id] else False

View File

@@ -0,0 +1,25 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Sequence, Tuple
# A single HTTP header as a (name, value) pair of raw bytes.
HeaderType = Tuple[bytes, bytes]
# An ordered collection of headers.
HeadersType = Sequence[HeaderType]
# A network endpoint as (host, port).
AddressType = Tuple[str, int]
# A UDP datagram payload together with its peer address.
DatagramType = Tuple[bytes, AddressType]

View File

@@ -0,0 +1,43 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
HandshakeCompleted,
HeadersReceived,
StreamEvent,
StreamReset,
StreamResetReceived,
)
__all__ = (
"Event",
"ConnectionTerminated",
"GoawayReceived",
"StreamEvent",
"StreamReset",
"StreamResetReceived",
"HeadersReceived",
"DataReceived",
"HandshakeCompleted",
"EarlyHeadersReceived",
)

View File

@@ -0,0 +1,202 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing
from dataclasses import dataclass, field
from .._typing import HeadersType
class Event:
    """
    Base class for HTTP events.
    This is an abstract base class that should not be initialized.
    """
    # Monotonic insertion id stamped by StreamMatrix; used to order events
    # across per-stream queues.
    _id: int
#
# Connection events
#
@dataclass
class ConnectionTerminated(Event):
    """
    Connection was terminated.
    Extends :class:`.Event`.
    """
    #: Reason for closing the connection.
    error_code: int = 0
    #: Optional message with more information (excluded from __eq__ via compare=False)
    message: str | None = field(default=None, compare=False)
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return f"{cls}(error_code={self.error_code!r}, message={self.message!r})"
@dataclass
class GoawayReceived(Event):
    """
    GOAWAY frame was received: the peer is shutting the connection down.
    Extends :class:`.Event`.
    """
    #: Highest stream ID that could be processed.
    last_stream_id: int
    #: Reason for closing the connection.
    error_code: int = 0
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(last_stream_id={self.last_stream_id!r}, "
            f"error_code={self.error_code!r})"
        )
#
# Stream events
#
@dataclass
class StreamEvent(Event):
    """
    Event on one HTTP stream.
    This is an abstract base class that should not be used directly.
    Extends :class:`.Event`.
    """
    #: Stream ID this event belongs to.
    stream_id: int
@dataclass
class StreamReset(StreamEvent):
    """
    One stream of an HTTP connection was reset.
    When a stream is reset, it must no longer be used, but the parent
    connection and other streams are unaffected.
    This is an abstract base class that should not be used directly.
    More specific subclasses (StreamResetSent or StreamResetReceived)
    should be emitted.
    Extends :class:`.StreamEvent`.
    """
    #: Reason for closing the stream.
    error_code: int = 0
    #: Whether the stream is considered closed by this reset (defaults to True).
    end_stream: bool = True
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return f"{cls}(stream_id={self.stream_id!r}, error_code={self.error_code!r})"
@dataclass
class StreamResetReceived(StreamReset):
    """
    One stream of an HTTP connection was reset by the peer.
    This probably means that we did something that the peer does not like.
    Extends :class:`.StreamReset`.
    """
@dataclass
class HandshakeCompleted(Event):
    """
    The transport handshake finished; carries the negotiated ALPN protocol.
    Extends :class:`.Event`.
    """
    #: Negotiated ALPN protocol identifier, or None when unavailable.
    alpn_protocol: str | None
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return f"{cls}(alpn={self.alpn_protocol})"
@dataclass
class HeadersReceived(StreamEvent):
    """
    A frame with HTTP headers was received.
    Extends :class:`.StreamEvent`.
    """
    #: The received HTTP headers
    headers: HeadersType
    #: Signals that data will not be sent by the peer over the stream.
    end_stream: bool = False
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(headers)={len(self.headers)}, end_stream={self.end_stream!r})"
        )
@dataclass
class DataReceived(StreamEvent):
    """
    A frame with HTTP data was received.
    Extends :class:`.StreamEvent`.
    """
    #: The received data.
    data: bytes
    #: Signals that no more data will be sent by the peer over the stream.
    end_stream: bool = False
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(data)={len(self.data)}, end_stream={self.end_stream!r})"
        )
@dataclass
class EarlyHeadersReceived(StreamEvent):
    """
    Headers received ahead of the final response headers.
    NOTE(review): presumably carries informational (1xx) responses —
    confirm with the protocol implementations that emit it.
    Extends :class:`.StreamEvent`.
    """
    #: The received HTTP headers
    headers: HeadersType
    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(headers)={len(self.headers)}, end_stream=False)"
        )
    @property
    def end_stream(self) -> typing.Literal[False]:
        # Early headers never terminate the stream; hard-wired for API
        # symmetry with HeadersReceived.
        return False

View File

@@ -0,0 +1,37 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._factories import HTTPProtocolFactory
from ._protocols import (
HTTP1Protocol,
HTTP2Protocol,
HTTP3Protocol,
HTTPOverQUICProtocol,
HTTPOverTCPProtocol,
HTTPProtocol,
)
__all__ = (
"HTTP1Protocol",
"HTTP2Protocol",
"HTTP3Protocol",
"HTTPOverQUICProtocol",
"HTTPOverTCPProtocol",
"HTTPProtocol",
"HTTPProtocolFactory",
)

View File

@@ -0,0 +1,90 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HTTP factories create HTTP protools based on defined set of arguments.
We define the :class:`HTTPProtocol` interface to allow interchange
HTTP versions and protocol implementations. But constructors of
the class is not part of the interface. Every implementation
can use a different options to init instances.
Factories unify access to the creation of the protocol instances,
so that clients and servers can swap protocol implementations,
delegating the initialization to factories.
"""
from __future__ import annotations
import importlib
import inspect
from abc import ABCMeta
from typing import Any
from ._protocols import HTTPOverQUICProtocol, HTTPOverTCPProtocol, HTTPProtocol
class HTTPProtocolFactory(metaclass=ABCMeta):
    """Unified entry point for constructing HTTP protocol state machines.

    Concrete implementations live in ``.protocols.httpN[._impl]`` modules
    and are discovered dynamically by this factory.
    """
    @staticmethod
    def new(
        type_protocol: type[HTTPProtocol],
        implementation: str | None = None,
        **kwargs: Any,
    ) -> HTTPOverQUICProtocol | HTTPOverTCPProtocol:
        """Create a new state-machine that target given protocol type."""
        # NOTE(review): assert is stripped under `python -O`; callers must
        # not rely on this guard for validation.
        assert type_protocol != HTTPProtocol, (
            "HTTPProtocol is ambiguous and cannot be requested in the factory."
        )
        package_name: str = __name__.split(".")[0]
        # Extract the version digit(s) from the class repr, e.g.
        # HTTP2Protocol -> "2", to build the target module name.
        version_target: str = "".join(
            c for c in str(type_protocol).replace(package_name, "") if c.isdigit()
        )
        module_expr: str = f".protocols.http{version_target}"
        if implementation:
            # An explicit implementation maps to a "_<name>" submodule.
            module_expr += f"._{implementation.lower()}"
        try:
            http_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.hface"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. Tried to import '{module_expr}'."
            ) from e
        # Collect every concrete protocol class exposed by the module.
        implementations: list[
            tuple[str, type[HTTPOverQUICProtocol | HTTPOverTCPProtocol]]
        ] = inspect.getmembers(
            http_module,
            lambda e: isinstance(e, type)
            and issubclass(e, (HTTPOverQUICProtocol, HTTPOverTCPProtocol)),
        )
        if not implementations:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit either from HTTPOverQUICProtocol or HTTPOverTCPProtocol."
            )
        # getmembers() returns members sorted by name; pop() therefore takes
        # the alphabetically last match — NOTE(review): confirm intended.
        implementation_target: type[HTTPOverQUICProtocol | HTTPOverTCPProtocol] = (
            implementations.pop()[1]
        )
        return implementation_target(**kwargs)

View File

@@ -0,0 +1,358 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing
from abc import ABCMeta, abstractmethod
from typing import Any, Sequence
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from .._typing import HeadersType
from ..events import Event
class BaseProtocol(metaclass=ABCMeta):
    """Sans-IO common methods whenever it is TCP, UDP or QUIC."""
    @abstractmethod
    def bytes_received(self, data: bytes) -> None:
        """
        Called when some data is received.
        """
        raise NotImplementedError
    # Sending direction
    @abstractmethod
    def bytes_to_send(self) -> bytes:
        """
        Returns data for sending out of the internal data buffer.
        """
        raise NotImplementedError
    @abstractmethod
    def connection_lost(self) -> None:
        """
        Called when the connection is lost or closed.
        """
        raise NotImplementedError
    # Deliberately not @abstractmethod: protocols without flow control may
    # leave the default NotImplementedError in place — TODO confirm intent.
    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        """
        Verify if the client should listen network incoming data for
        the flow control update purposes.
        """
        raise NotImplementedError
    def max_frame_size(self) -> int:
        """
        Determine if the remote set a limited size for each data frame.
        """
        raise NotImplementedError
class OverTCPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top TCP.
    """
    @abstractmethod
    def eof_received(self) -> None:
        """
        Called when the other end signals it won't send any more data.
        """
        raise NotImplementedError
class OverUDPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top UDP.
    """
class OverQUICProtocol(OverUDPProtocol):
    """
    Interface for sans-IO protocols on top of QUIC.
    """
    @property
    @abstractmethod
    def connection_ids(self) -> Sequence[bytes]:
        """
        QUIC connection IDs
        This property can be used to assign UDP packets to QUIC connections.
        :return: a sequence of connection IDs
        """
        raise NotImplementedError
    @property
    @abstractmethod
    def session_ticket(self) -> Any | None:
        """Session ticket usable for session resumption, if any."""
        raise NotImplementedError
    @typing.overload
    def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...
    @typing.overload
    def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...
    @abstractmethod
    def getpeercert(self, *, binary_form: bool = False) -> bytes | dict[str, Any]:
        """Peer certificate: DER bytes when binary_form, else a decoded dict."""
        raise NotImplementedError
    @typing.overload
    def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...
    @typing.overload
    def getissuercert(
        self, *, binary_form: Literal[False] = ...
    ) -> dict[str, Any] | None: ...
    @abstractmethod
    def getissuercert(
        self, *, binary_form: bool = False
    ) -> bytes | dict[str, Any] | None:
        """Issuer certificate, in the same forms as :meth:`getpeercert`; may be None."""
        raise NotImplementedError
    @abstractmethod
    def cipher(self) -> str | None:
        """Name of the negotiated cipher, if known."""
        raise NotImplementedError
class HTTPProtocol(metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP connection
    """
    #: Short name of the backing implementation (set by concrete classes).
    implementation: str
    @staticmethod
    @abstractmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Return exception types that should be handled in your application."""
        raise NotImplementedError
    @property
    @abstractmethod
    def multiplexed(self) -> bool:
        """
        Whether this connection supports multiple parallel streams.
        Returns ``True`` for HTTP/2 and HTTP/3 connections.
        """
        raise NotImplementedError
    @property
    @abstractmethod
    def max_stream_count(self) -> int:
        """Determine how many concurrent streams the connection can handle."""
        raise NotImplementedError
    @abstractmethod
    def is_idle(self) -> bool:
        """
        Return True if this connection is BOTH available and not doing anything.
        """
        raise NotImplementedError
    @abstractmethod
    def is_available(self) -> bool:
        """
        Return whether this connection is capable of opening new streams.
        """
        raise NotImplementedError
    @abstractmethod
    def has_expired(self) -> bool:
        """
        Return whether this connection is closed or should be closed.
        """
        raise NotImplementedError
    @abstractmethod
    def get_available_stream_id(self) -> int:
        """
        Return an ID that can be used to create a new stream.
        Use the returned ID with :meth:`.submit_headers` to create the stream.
        This method may or may not return one value until that method is called.
        :return: stream ID
        """
        raise NotImplementedError
    @abstractmethod
    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP headers.
        If this is a client connection, this method starts an HTTP request.
        If this is a server connection, it starts an HTTP response.
        :param stream_id: stream ID
        :param headers: HTTP headers
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError
    @abstractmethod
    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP data.
        :param stream_id: stream ID
        :param data: payload
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError
    @abstractmethod
    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        """
        Immediately terminate a stream.
        Stream reset is used to request cancellation of a stream
        or to indicate that an error condition has occurred.
        Use :attr:`.error_codes` to obtain error codes for common problems.
        :param stream_id: stream ID
        :param error_code: indicates why the stream is being terminated
        """
        raise NotImplementedError
    @abstractmethod
    def submit_close(self, error_code: int = 0) -> None:
        """
        Submit graceful close the connection.
        Use :attr:`.error_codes` to obtain error codes for common problems.
        :param error_code: indicates why the connection is being closed
        """
        raise NotImplementedError
    @abstractmethod
    def next_event(self, stream_id: int | None = None) -> Event | None:
        """
        Consume next HTTP event.
        :return: an event instance
        """
        raise NotImplementedError
    def events(self, stream_id: int | None = None) -> typing.Iterator[Event]:
        """
        Consume available HTTP events.
        :return: an iterator that unpack "next_event" until exhausted.
        """
        while True:
            ev = self.next_event(stream_id=stream_id)
            if ev is None:
                break
            yield ev
    @abstractmethod
    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Verify if there is queued event waiting to be consumed."""
        raise NotImplementedError
    @abstractmethod
    def reshelve(self, *events: Event) -> None:
        """Put back events into the deque."""
        raise NotImplementedError
    @abstractmethod
    def ping(self) -> None:
        """Send a PING frame to the remote peer. Thus keeping the connection alive."""
        raise NotImplementedError
class HTTPOverTCPProtocol(HTTPProtocol, OverTCPProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a TCP connection
    An interface for HTTP/1 and HTTP/2 protocols.
    Extends :class:`.HTTPProtocol`.
    """
class HTTPOverQUICProtocol(HTTPProtocol, OverQUICProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a QUIC connection
    Abstract base class for HTTP/3 protocols.
    Extends :class:`.HTTPProtocol`.
    """
class HTTP1Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/1 connection
    An interface for HTTP/1 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """
    @property
    def multiplexed(self) -> bool:
        # HTTP/1 serves a single request/response at a time.
        return False
    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        # HTTP/1 has no flow-control frames; NotImplemented marks the
        # question as not applicable for this protocol.
        return NotImplemented  # type: ignore[no-any-return]
class HTTP2Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/2 connection
    An abstract base class for HTTP/2 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """
    @property
    def multiplexed(self) -> bool:
        # HTTP/2 multiplexes many streams over one TCP connection.
        return True
class HTTP3Protocol(HTTPOverQUICProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/3 connection
    An abstract base class for HTTP/3 implementations.
    Extends :class:`.HTTPOverQUICProtocol`
    """
    @property
    def multiplexed(self) -> bool:
        # HTTP/3 multiplexes many streams over one QUIC connection.
        return True

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._h11 import HTTP1ProtocolHyperImpl
__all__ = ("HTTP1ProtocolHyperImpl",)

View File

@@ -0,0 +1,347 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from functools import lru_cache
import h11
from h11._state import _SWITCH_UPGRADE, ConnectionState
from ..._stream_matrix import StreamMatrix
from ..._typing import HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
HeadersReceived,
)
from .._protocols import HTTP1Protocol
@lru_cache(maxsize=64)
def capitalize_header_name(name: bytes) -> bytes:
    """
    Take a header name and capitalize each dash-separated segment.

    Only ``-`` is treated as a separator; an underscore is kept as a
    regular character within its segment.

    >>> capitalize_header_name(b"x-hEllo-wORLD")
    b'X-Hello-World'
    >>> capitalize_header_name(b"server")
    b'Server'
    >>> capitalize_header_name(b"contEnt-TYPE")
    b'Content-Type'
    >>> capitalize_header_name(b"content_type")
    b'Content_type'
    """
    return b"-".join(el.capitalize() for el in name.split(b"-"))
def headers_to_request(headers: HeadersType) -> h11.Event:
    """Build an :class:`h11.Request` from HTTP/2-style (pseudo-)headers."""
    method = None
    authority = None
    path = None
    host = None
    plain_headers = []

    for raw_name, raw_value in headers:
        if not raw_name.startswith(b":"):
            if host is None and raw_name == b"host":
                host = raw_value
            # We found that many projects... actually expect the header name to be sent capitalized... hardcoded
            # within their tests. Bad news, we have to keep doing this nonsense (namely capitalize_header_name)
            plain_headers.append((capitalize_header_name(raw_name), raw_value))
            continue
        if raw_name == b":method":
            method = raw_value
        elif raw_name == b":scheme":
            continue  # scheme is implicit for HTTP/1
        elif raw_name == b":authority":
            authority = raw_value
        elif raw_name == b":path":
            path = raw_value
        else:
            raise ValueError("Unexpected request header: " + raw_name.decode())

    if authority is None:
        raise ValueError("Missing request header: :authority")

    # CONNECT requests are a special case: they target the authority itself.
    target = authority if (method == b"CONNECT" and path is None) else path

    if host is None:
        plain_headers.insert(0, (b"Host", authority))
    elif host != authority:
        raise ValueError("Host header does not match :authority.")

    return h11.Request(
        method=method,  # type: ignore[arg-type]
        headers=plain_headers,
        target=target,
    )
def headers_from_response(
    response: h11.InformationalResponse | h11.Response,
) -> HeadersType:
    """
    Converts an HTTP/1.0 or HTTP/1.1 response to HTTP/2-like headers.
    Generates the pseudo (colon) headers from the response status line.
    """
    status_pseudo_header = (b":status", str(response.status_code).encode("ascii"))
    return [status_pseudo_header, *response.headers.raw_items()]
class RelaxConnectionState(ConnectionState):
    """An h11 ``ConnectionState`` that tolerates unsolicited protocol switches.

    Some servers answer ``101 Switching Protocols`` without the client having
    registered a pending switch proposal. Strict h11 raises in that case; the
    legacy ``http.client`` accepted it, so we temporarily relax h11 to match.
    """

    def process_event(  # type: ignore[no-untyped-def]
        self,
        role,
        event_type,
        server_switch_event=None,
    ) -> None:
        if server_switch_event is not None:
            if server_switch_event not in self.pending_switch_proposals:
                if server_switch_event is _SWITCH_UPGRADE:
                    warnings.warn(
                        f"Received server {server_switch_event} event without a pending proposal. "
                        "This will raise an exception in a future version. It is temporarily relaxed to match the "
                        "legacy http.client standard library.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    # Register the proposal after the fact so the parent state
                    # machine accepts the switch instead of raising.
                    self.pending_switch_proposals.add(_SWITCH_UPGRADE)
        return super().process_event(role, event_type, server_switch_event)
class HTTP1ProtocolHyperImpl(HTTP1Protocol):
    """Sans-IO HTTP/1.x protocol implementation backed by the ``h11`` library."""

    implementation: str = "h11"

    def __init__(self) -> None:
        self._connection: h11.Connection = h11.Connection(h11.CLIENT)
        # Swap in the relaxed state machine so unsolicited 101 upgrades
        # do not raise (see RelaxConnectionState above).
        self._connection._cstate = RelaxConnectionState()
        # Outgoing byte chunks pending a bytes_to_send() call.
        self._data_buffer: list[bytes] = []
        self._events: StreamMatrix = StreamMatrix()
        self._terminated: bool = False
        self._switched: bool = False
        # HTTP/1 has no real streams; we emulate one stream id per
        # request/response cycle to mimic HTTP/2+ semantics.
        self._current_stream_id: int = 1

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types callers should treat as protocol-level errors."""
        return h11.LocalProtocolError, h11.ProtocolError, h11.RemoteProtocolError

    def is_available(self) -> bool:
        # A new request may start only once both peers are back to IDLE.
        return self._connection.our_state == self._connection.their_state == h11.IDLE

    @property
    def max_stream_count(self) -> int:
        # No multiplexing and no pipelining support: one stream at most.
        return 1

    def is_idle(self) -> bool:
        return self._connection.their_state in {
            h11.IDLE,
            h11.MUST_CLOSE,
        }

    def has_expired(self) -> bool:
        return self._terminated

    def get_available_stream_id(self) -> int:
        if not self.is_available():
            raise RuntimeError(
                "Cannot generate a new stream ID because the connection is not idle. "
                "HTTP/1.1 is not multiplexed and we do not support HTTP pipelining."
            )
        return self._current_stream_id

    def submit_close(self, error_code: int = 0) -> None:
        pass  # no-op

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")
        self._h11_submit(headers_to_request(headers))
        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # After a protocol switch (e.g. WebSocket upgrade) bytes are
            # forwarded verbatim; h11 no longer frames them.
            self._data_buffer.append(data)
            if end_stream:
                self._events.append(self._connection_terminated())
            return
        self._h11_submit(h11.Data(data))
        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        # HTTP/1 cannot submit a stream (it does not have real streams).
        # But if there are no other streams, we can close the connection instead.
        self.connection_lost()

    def connection_lost(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # This method is called when the connection is closed without an EOF.
        # But not all connections support EOF, so being here does not
        # necessarily mean that something when wrong.
        #
        # The tricky part is that HTTP/1.0 server can send responses
        # without Content-Length or Transfer-Encoding headers,
        # meaning that a response body is closed with the connection.
        # In such cases, we require a proper EOF to distinguish complete
        # messages from partial messages interrupted by network failure.
        if not self._terminated:
            self._connection.send_failed()
            self._events.append(self._connection_terminated())

    def eof_received(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # h11 interprets an empty feed as end-of-file.
        self._h11_data_received(b"")

    def bytes_received(self, data: bytes) -> None:
        if not data:
            return  # h11 treats empty data as EOF.
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # Tunneled bytes after an upgrade bypass h11 entirely.
            self._events.append(DataReceived(self._current_stream_id, data))
            return
        else:
            self._h11_data_received(data)

    def bytes_to_send(self) -> bytes:
        data = b"".join(self._data_buffer)
        self._data_buffer.clear()
        self._maybe_start_next_cycle()
        return data

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _h11_submit(self, h11_event: h11.Event) -> None:
        # send_with_data_passthrough avoids copying body chunks.
        chunks = self._connection.send_with_data_passthrough(h11_event)
        if chunks:
            self._data_buffer += chunks

    def _h11_data_received(self, data: bytes) -> None:
        self._connection.receive_data(data)
        self._fetch_events()

    def _fetch_events(self) -> None:
        """Drain pending h11 events and translate them into our Event objects."""
        a = self._events.append
        while not self._terminated:
            try:
                h11_event = self._connection.next_event()
            except h11.RemoteProtocolError as e:
                a(self._connection_terminated(e.error_status_hint, str(e)))
                break
            ev_type = h11_event.__class__
            if h11_event is h11.NEED_DATA or h11_event is h11.PAUSED:
                if h11.MUST_CLOSE == self._connection.their_state:
                    a(self._connection_terminated())
                else:
                    break
            elif ev_type is h11.Response:
                a(
                    HeadersReceived(
                        self._current_stream_id,
                        headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.InformationalResponse:
                a(
                    EarlyHeadersReceived(
                        stream_id=self._current_stream_id,
                        headers=headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.Data:
                # officially h11 typed data as "bytes"
                # but we... found that it store bytearray sometime.
                payload = h11_event.data  # type: ignore[union-attr]
                a(
                    DataReceived(
                        self._current_stream_id,
                        bytes(payload) if payload.__class__ is bytearray else payload,
                    )
                )
            elif ev_type is h11.EndOfMessage:
                # HTTP/2 and HTTP/3 send END_STREAM flag with HEADERS and DATA frames.
                # We emulate similar behavior for HTTP/1.
                if h11_event.headers:  # type: ignore[union-attr]
                    last_event: HeadersReceived | DataReceived = HeadersReceived(
                        self._current_stream_id,
                        h11_event.headers,  # type: ignore[union-attr]
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                else:
                    last_event = DataReceived(
                        self._current_stream_id,
                        b"",
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                a(last_event)
                self._maybe_start_next_cycle()
            elif ev_type is h11.ConnectionClosed:
                a(self._connection_terminated())

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> Event:
        # Marks the connection dead and returns the event to enqueue.
        self._terminated = True
        return ConnectionTerminated(error_code, message)

    def _maybe_start_next_cycle(self) -> None:
        # When both sides reach DONE, recycle h11 for the next keep-alive
        # request and bump our emulated stream id.
        if h11.DONE == self._connection.our_state == self._connection.their_state:
            self._connection.start_next_cycle()
            self._current_stream_id += 1
        if h11.SWITCHED_PROTOCOL == self._connection.their_state and not self._switched:
            # Flush any bytes h11 buffered past the protocol-switch boundary.
            data, closed = self._connection.trailing_data
            if data:
                self._events.append(DataReceived(self._current_stream_id, data))
            self._switched = True

    def reshelve(self, *events: Event) -> None:
        # Put events back at the head of the queue, preserving their order.
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        raise NotImplementedError("http1 does not support PING")

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._h2 import HTTP2ProtocolHyperImpl
__all__ = ("HTTP2ProtocolHyperImpl",)

View File

@@ -0,0 +1,312 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from secrets import token_bytes
from typing import Iterator
import jh2.config # type: ignore
import jh2.connection # type: ignore
import jh2.errors # type: ignore
import jh2.events # type: ignore
import jh2.exceptions # type: ignore
import jh2.settings # type: ignore
from ..._stream_matrix import StreamMatrix
from ..._typing import HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
HandshakeCompleted,
HeadersReceived,
StreamResetReceived,
)
from .._protocols import HTTP2Protocol
class _PatchedH2Connection(jh2.connection.H2Connection):  # type: ignore[misc]
    """
    This is a performance hotfix class. We internally, already keep
    track of the open stream count.
    """

    def __init__(
        self,
        config: jh2.config.H2Configuration | None = None,
        observable_impl: HTTP2ProtocolHyperImpl | None = None,
    ) -> None:
        super().__init__(config=config)
        # by default CONNECT is disabled
        # we need it to support natively WebSocket over HTTP/2 for example.
        self.local_settings = jh2.settings.Settings(
            client=True,
            initial_values={
                jh2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                jh2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: self.DEFAULT_MAX_HEADER_LIST_SIZE,
                jh2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
            },
        )
        self._observable_impl = observable_impl

    def _open_streams(self, *args, **kwargs) -> int:  # type: ignore[no-untyped-def]
        # Delegate the open-stream count to our own (cheaper) counter
        # when the owning protocol is attached.
        if self._observable_impl is not None:
            return self._observable_impl._open_stream_count
        return super()._open_streams(*args, **kwargs)  # type: ignore[no-any-return]

    def _receive_goaway_frame(self, frame):  # type: ignore[no-untyped-def]
        """
        Receive a GOAWAY frame on the connection.
        We purposely override this method to work around a known bug of jh2.
        """
        events = self.state_machine.process_input(
            jh2.connection.ConnectionInputs.RECV_GOAWAY
        )
        err_code = jh2.errors._error_code_from_int(frame.error_code)
        # GOAWAY allows an
        # endpoint to gracefully stop accepting new streams while still
        # finishing processing of previously established streams.
        # see https://tools.ietf.org/html/rfc7540#section-6.8
        # hyper/h2 does not allow such a thing for now. let's work around this.
        if (
            err_code == 0
            and self._observable_impl is not None
            and self._observable_impl._open_stream_count > 0
        ):
            # Graceful shutdown: keep the connection open so in-flight
            # streams can complete instead of being torn down.
            self.state_machine.state = jh2.connection.ConnectionState.CLIENT_OPEN
        # Clear the outbound data buffer: we cannot send further data now.
        self.clear_outbound_data_buffer()
        # Fire an appropriate ConnectionTerminated event.
        new_event = jh2.events.ConnectionTerminated()
        new_event.error_code = err_code
        new_event.last_stream_id = frame.last_stream_id
        new_event.additional_data = (
            frame.additional_data if frame.additional_data else None
        )
        events.append(new_event)
        return [], events
# jh2 event types that carry response or trailer headers; both are mapped
# to our HeadersReceived event.
HEADER_OR_TRAILER_TYPE_SET = {
    jh2.events.ResponseReceived,
    jh2.events.TrailersReceived,
}
class HTTP2ProtocolHyperImpl(HTTP2Protocol):
    """Sans-IO HTTP/2 protocol implementation backed by the ``jh2`` library."""

    implementation: str = "h2"

    def __init__(
        self,
        *,
        validate_outbound_headers: bool = False,
        validate_inbound_headers: bool = False,
        normalize_outbound_headers: bool = False,
        normalize_inbound_headers: bool = True,
    ) -> None:
        self._connection: jh2.connection.H2Connection = _PatchedH2Connection(
            jh2.config.H2Configuration(
                client_side=True,
                validate_outbound_headers=validate_outbound_headers,
                normalize_outbound_headers=normalize_outbound_headers,
                validate_inbound_headers=validate_inbound_headers,
                normalize_inbound_headers=normalize_inbound_headers,
            ),
            observable_impl=self,
        )
        self._open_stream_count: int = 0
        self._connection.initiate_connection()
        # Enlarge the connection-level flow-control window (2**24 = 16 MiB).
        self._connection.increment_flow_control_window(2**24)
        self._events: StreamMatrix = StreamMatrix()
        self._terminated: bool = False
        # Set when the server sent a graceful (NO_ERROR) GOAWAY.
        self._goaway_to_honor: bool = False
        self._max_stream_count: int = (
            self._connection.remote_settings.max_concurrent_streams
        )
        self._max_frame_size: int = self._connection.remote_settings.max_frame_size

    def max_frame_size(self) -> int:
        # NOTE(review): plain method while max_stream_count is a @property —
        # confirm callers invoke this as max_frame_size().
        return self._max_frame_size

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types callers should treat as protocol-level errors."""
        return jh2.exceptions.ProtocolError, jh2.exceptions.H2Error

    def is_available(self) -> bool:
        if self._terminated:
            return False
        return self._max_stream_count > self._open_stream_count

    @property
    def max_stream_count(self) -> int:
        return self._max_stream_count

    def is_idle(self) -> bool:
        return self._terminated is False and self._open_stream_count == 0

    def has_expired(self) -> bool:
        # A GOAWAY we still honor also forbids opening new streams.
        return self._terminated or self._goaway_to_honor

    def get_available_stream_id(self) -> int:
        return self._connection.get_next_available_stream_id()  # type: ignore[no-any-return]

    def submit_close(self, error_code: int = 0) -> None:
        self._connection.close_connection(error_code)

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        self._connection.send_headers(stream_id, headers, end_stream)
        # Enlarge the per-stream flow-control window (2**24 = 16 MiB).
        self._connection.increment_flow_control_window(2**24, stream_id=stream_id)
        self._open_stream_count += 1

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        self._connection.send_data(stream_id, data, end_stream)

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        self._connection.reset_stream(stream_id, error_code)

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _map_events(self, h2_events: list[jh2.events.Event]) -> Iterator[Event]:
        """Translate raw jh2 events into our protocol-agnostic Event objects."""
        for e in h2_events:
            ev_type = e.__class__
            if ev_type in HEADER_OR_TRAILER_TYPE_SET:
                end_stream = e.stream_ended is not None
                if end_stream:
                    self._open_stream_count -= 1
                    # Retire the stream manually to keep jh2 bookkeeping lean.
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield HeadersReceived(e.stream_id, e.headers, end_stream=end_stream)
            elif ev_type is jh2.events.DataReceived:
                end_stream = e.stream_ended is not None
                if end_stream:
                    self._open_stream_count -= 1
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                self._connection.acknowledge_received_data(
                    e.flow_controlled_length, e.stream_id
                )
                yield DataReceived(e.stream_id, e.data, end_stream=end_stream)
            elif ev_type is jh2.events.InformationalResponseReceived:
                yield EarlyHeadersReceived(
                    e.stream_id,
                    e.headers,
                )
            elif ev_type is jh2.events.StreamReset:
                self._open_stream_count -= 1
                # event StreamEnded may occur before StreamReset
                if e.stream_id in self._connection.streams:
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield StreamResetReceived(e.stream_id, e.error_code)
            elif ev_type is jh2.events.ConnectionTerminated:
                # ConnectionTerminated from h2 means that GOAWAY was received.
                # A server can send GOAWAY for graceful shutdown, where clients
                # do not open new streams, but inflight requests can be completed.
                #
                # Saying "connection was terminated" can be confusing,
                # so we emit an event called "GoawayReceived".
                if e.error_code == 0:
                    self._goaway_to_honor = True
                    yield GoawayReceived(e.last_stream_id, e.error_code)
                else:
                    self._terminated = True
                    yield ConnectionTerminated(e.error_code, None)
            elif ev_type in {
                jh2.events.SettingsAcknowledged,
                jh2.events.RemoteSettingsChanged,
            }:
                yield HandshakeCompleted(alpn_protocol="h2")

    def connection_lost(self) -> None:
        self._connection_terminated()

    def eof_received(self) -> None:
        self._connection_terminated()

    def bytes_received(self, data: bytes) -> None:
        if not data:
            return
        try:
            h2_events = self._connection.receive_data(data)
        except jh2.exceptions.ProtocolError as e:
            self._connection_terminated(e.error_code, str(e))
        else:
            self._events.extend(self._map_events(h2_events))
            # we want to perpetually mark the connection as "saturated"
            if self._goaway_to_honor:
                self._max_stream_count = self._open_stream_count
            if self._connection.remote_settings.has_update:
                if not self._goaway_to_honor:
                    self._max_stream_count = (
                        self._connection.remote_settings.max_concurrent_streams
                    )
                self._max_frame_size = self._connection.remote_settings.max_frame_size

    def bytes_to_send(self) -> bytes:
        return self._connection.data_to_send()  # type: ignore[no-any-return]

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> None:
        # Idempotent: only the first termination enqueues an event.
        if self._terminated:
            return
        error_code = int(error_code)  # Convert h2 IntEnum to an actual int
        self._terminated = True
        self._events.append(ConnectionTerminated(error_code, message))

    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        flow_remaining_bytes: int = self._connection.local_flow_control_window(
            stream_id
        )
        if amt is None:
            return flow_remaining_bytes == 0
        return amt > flow_remaining_bytes

    def reshelve(self, *events: Event) -> None:
        # Put events back at the head of the queue, preserving their order.
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        self._connection.ping(token_bytes(8))

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._qh3 import HTTP3ProtocolAioQuicImpl
__all__ = ("HTTP3ProtocolAioQuicImpl",)

View File

@@ -0,0 +1,592 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import ssl
import typing
from collections import deque
from os import environ
from random import randint
from time import time as monotonic
from typing import Any, Iterable, Sequence
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from qh3 import (
CipherSuite,
H3Connection,
H3Error,
ProtocolError,
QuicConfiguration,
QuicConnection,
QuicConnectionError,
QuicFileLogger,
SessionTicket,
h3_events,
quic_events,
)
from qh3.h3.connection import FrameType
from qh3.quic.connection import QuicConnectionState
from ..._configuration import QuicTLSConfig
from ..._stream_matrix import StreamMatrix
from ..._typing import AddressType, HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
)
from ...events import HandshakeCompleted as _HandshakeCompleted
from ...events import HeadersReceived, StreamResetReceived
from .._protocols import HTTP3Protocol
# The only QUIC-layer events we translate into our protocol events;
# everything else is consumed internally by qh3.
QUIC_RELEVANT_EVENT_TYPES = {
    quic_events.HandshakeCompleted,
    quic_events.ConnectionTerminated,
    quic_events.StreamReset,
}
class HTTP3ProtocolAioQuicImpl(HTTP3Protocol):
implementation: str = "qh3"
def __init__(
self,
*,
remote_address: AddressType,
server_name: str,
tls_config: QuicTLSConfig,
) -> None:
keylogfile_path: str | None = environ.get("SSLKEYLOGFILE", None)
qlogdir_path: str | None = environ.get("QUICLOGDIR", None)
self._configuration: QuicConfiguration = QuicConfiguration(
is_client=True,
verify_mode=ssl.CERT_NONE if tls_config.insecure else ssl.CERT_REQUIRED,
cafile=tls_config.cafile,
capath=tls_config.capath,
cadata=tls_config.cadata,
alpn_protocols=["h3"],
session_ticket=tls_config.session_ticket,
server_name=server_name,
hostname_checks_common_name=tls_config.cert_use_common_name,
assert_fingerprint=tls_config.cert_fingerprint,
verify_hostname=tls_config.verify_hostname,
secrets_log_file=open(keylogfile_path, "w") if keylogfile_path else None, # type: ignore[arg-type]
quic_logger=QuicFileLogger(qlogdir_path) if qlogdir_path else None,
idle_timeout=tls_config.idle_timeout,
max_data=2**24,
max_stream_data=2**24,
)
if tls_config.ciphers:
available_ciphers = {c.name: c for c in CipherSuite}
chosen_ciphers: list[CipherSuite] = []
for cipher in tls_config.ciphers:
if "name" in cipher and isinstance(cipher["name"], str):
chosen_ciphers.append(
available_ciphers[cipher["name"].replace("TLS_", "")]
)
if len(chosen_ciphers) == 0:
raise ValueError(
f"Unable to find a compatible cipher in '{tls_config.ciphers}' to establish a QUIC connection. "
f"QUIC support one of '{['TLS_' + e for e in available_ciphers.keys()]}' only."
)
self._configuration.cipher_suites = chosen_ciphers
if tls_config.certfile:
self._configuration.load_cert_chain(
tls_config.certfile,
tls_config.keyfile,
tls_config.keypassword,
)
self._quic: QuicConnection = QuicConnection(configuration=self._configuration)
self._connection_ids: set[bytes] = set()
self._remote_address = remote_address
self._events: StreamMatrix = StreamMatrix()
self._packets: deque[bytes] = deque()
self._http: H3Connection | None = None
self._terminated: bool = False
self._data_in_flight: bool = False
self._open_stream_count: int = 0
self._total_stream_count: int = 0
self._goaway_to_honor: bool = False
self._max_stream_count: int = (
100 # safe-default, broadly used. (and set by qh3)
)
self._max_frame_size: int | None = None
@staticmethod
def exceptions() -> tuple[type[BaseException], ...]:
return ProtocolError, H3Error, QuicConnectionError, AssertionError
@property
def max_stream_count(self) -> int:
return self._max_stream_count
def is_available(self) -> bool:
return (
self._terminated is False
and self._max_stream_count > self._quic.open_outbound_streams
)
def is_idle(self) -> bool:
return self._terminated is False and self._open_stream_count == 0
def has_expired(self) -> bool:
if not self._terminated and not self._goaway_to_honor:
now = monotonic()
self._quic.handle_timer(now)
self._packets.extend(
map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
)
if self._quic._state in {
QuicConnectionState.CLOSING,
QuicConnectionState.TERMINATED,
}:
self._terminated = True
if (
hasattr(self._quic, "_close_event")
and self._quic._close_event is not None
):
self._events.extend(self._map_quic_event(self._quic._close_event))
self._terminated = True
return self._terminated or self._goaway_to_honor
@property
def session_ticket(self) -> SessionTicket | None:
return self._quic.tls.session_ticket if self._quic and self._quic.tls else None
def get_available_stream_id(self) -> int:
return self._quic.get_next_available_stream_id()
def submit_close(self, error_code: int = 0) -> None:
# QUIC has two different frame types for closing the connection.
# From RFC 9000 (QUIC: A UDP-Based Multiplexed and Secure Transport):
#
# > An endpoint sends a CONNECTION_CLOSE frame (type=0x1c or 0x1d)
# > to notify its peer that the connection is being closed.
# > The CONNECTION_CLOSE frame with a type of 0x1c is used to signal errors
# > at only the QUIC layer, or the absence of errors (with the NO_ERROR code).
# > The CONNECTION_CLOSE frame with a type of 0x1d is used
# > to signal an error with the application that uses QUIC.
frame_type = 0x1D if error_code else 0x1C
self._quic.close(error_code=error_code, frame_type=frame_type)
def submit_headers(
self, stream_id: int, headers: HeadersType, end_stream: bool = False
) -> None:
assert self._http is not None
self._open_stream_count += 1
self._total_stream_count += 1
self._http.send_headers(stream_id, list(headers), end_stream)
def submit_data(
self, stream_id: int, data: bytes, end_stream: bool = False
) -> None:
assert self._http is not None
self._http.send_data(stream_id, data, end_stream)
if end_stream is False:
self._data_in_flight = True
def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
self._quic.reset_stream(stream_id, error_code)
def next_event(self, stream_id: int | None = None) -> Event | None:
return self._events.popleft(stream_id=stream_id)
def has_pending_event(
self,
*,
stream_id: int | None = None,
excl_event: tuple[type[Event], ...] | None = None,
) -> bool:
return self._events.has(stream_id=stream_id, excl_event=excl_event)
@property
def connection_ids(self) -> Sequence[bytes]:
return list(self._connection_ids)
def connection_lost(self) -> None:
self._terminated = True
self._events.append(ConnectionTerminated())
def bytes_received(self, data: bytes) -> None:
self._quic.receive_datagram(data, self._remote_address, now=monotonic())
self._fetch_events()
if self._data_in_flight:
self._data_in_flight = False
# we want to perpetually mark the connection as "saturated"
if self._goaway_to_honor:
self._max_stream_count = self._open_stream_count
else:
# This section may confuse beginners
# See RFC 9000 -> 19.11. MAX_STREAMS Frames
# footer extract:
# Note that these frames (and the corresponding transport parameters)
# do not describe the number of streams that can be opened
# concurrently. The limit includes streams that have been closed as
# well as those that are open.
#
# so, finding that remote_max_streams_bidi is increasing constantly is normal.
new_stream_limit = (
self._quic._remote_max_streams_bidi - self._total_stream_count
)
if (
new_stream_limit
and new_stream_limit != self._max_stream_count
and new_stream_limit > 0
):
self._max_stream_count = new_stream_limit
if (
self._quic._remote_max_stream_data_bidi_remote
and self._quic._remote_max_stream_data_bidi_remote
!= self._max_frame_size
):
self._max_frame_size = self._quic._remote_max_stream_data_bidi_remote
def bytes_to_send(self) -> bytes:
if not self._packets:
now = monotonic()
if self._http is None:
self._quic.connect(self._remote_address, now=now)
self._http = H3Connection(self._quic)
# the QUIC state machine returns datagrams (addr, packet)
# the client never have to worry about the destination
# unless server yield a preferred address?
self._packets.extend(
map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
)
if not self._packets:
return b""
# it is absolutely crucial to return one at a time
# because UDP don't support sending more than
# MTU (to be more precise, lowest MTU in the network path from A (you) to B (server))
return self._packets.popleft()
def _fetch_events(self) -> None:
assert self._http is not None
for quic_event in iter(self._quic.next_event, None):
self._events.extend(self._map_quic_event(quic_event))
for h3_event in self._http.handle_event(quic_event):
self._events.extend(self._map_h3_event(h3_event))
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
self._events.extend(self._map_quic_event(self._quic._close_event))
def _map_quic_event(self, quic_event: quic_events.QuicEvent) -> Iterable[Event]:
ev_type = quic_event.__class__
# fastest path execution, most of the time we don't have those
# 3 event types.
if ev_type not in QUIC_RELEVANT_EVENT_TYPES:
return
if ev_type is quic_events.HandshakeCompleted:
yield _HandshakeCompleted(quic_event.alpn_protocol) # type: ignore[attr-defined]
elif ev_type is quic_events.ConnectionTerminated:
if quic_event.frame_type == FrameType.GOAWAY.value: # type: ignore[attr-defined]
self._goaway_to_honor = True
stream_list: list[int] = [
e for e in self._events._matrix.keys() if e is not None
]
yield GoawayReceived(stream_list[-1], quic_event.error_code) # type: ignore[attr-defined]
else:
self._terminated = True
yield ConnectionTerminated(
quic_event.error_code, # type: ignore[attr-defined]
quic_event.reason_phrase, # type: ignore[attr-defined]
)
elif ev_type is quic_events.StreamReset:
self._open_stream_count -= 1
yield StreamResetReceived(quic_event.stream_id, quic_event.error_code) # type: ignore[attr-defined]
def _map_h3_event(self, h3_event: h3_events.H3Event) -> Iterable[Event]:
    """Translate a qh3 HTTP/3 event into zero or more internal events.

    Side effect: decrements ``_open_stream_count`` whenever the peer's
    frame carries the end-of-stream flag.
    """
    ev_type = h3_event.__class__
    if ev_type is h3_events.HeadersReceived:
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield HeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
            h3_event.stream_ended,  # type: ignore[attr-defined]
        )
    elif ev_type is h3_events.DataReceived:
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield DataReceived(h3_event.stream_id, h3_event.data, h3_event.stream_ended)  # type: ignore[attr-defined]
    elif ev_type is h3_events.InformationalHeadersReceived:
        # 1xx interim responses (e.g. 103 Early Hints) never end the stream.
        yield EarlyHeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
        )
def should_wait_remote_flow_control(
    self, stream_id: int, amt: int | None = None
) -> bool | None:
    """Whether the caller should pause before sending more data.

    Both parameters are ignored here: the decision is connection-wide and
    based solely on whether DATA is still in flight toward the peer.
    """
    return self._data_in_flight
@typing.overload
def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...
@typing.overload
def getissuercert(
    self, *, binary_form: Literal[False] = ...
) -> dict[str, Any] | None: ...
def getissuercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any] | None:
    """Return the peer's issuer (first intermediate) certificate.

    Mirrors the shape of ``ssl.SSLSocket.getpeercert``: either the DER
    bytes (``binary_form=True``) or a dict with subject/issuer RDN
    1-tuples and validity dates formatted as ``"%b %d %H:%M:%S %Y UTC"``.

    Returns ``None`` when the server sent no intermediate certificate.

    :raises ValueError: if called before the TLS handshake completed.
    """
    x509_certificate = self._quic.get_peercert()
    if x509_certificate is None:
        raise ValueError("TLS handshake has not been done yet")
    if not self._quic.get_issuercerts():
        return None
    x509_certificate = self._quic.get_issuercerts()[0]
    if binary_form:
        return x509_certificate.public_bytes()
    # (fix) removed a leftover bare `datetime.datetime.fromtimestamp(...)`
    # expression whose result was discarded — dead code with no side effect.
    issuer_info = {
        # NOTE(review): assumes the backend exposes a 0-based version number,
        # hence the +1 to match the ssl module convention — confirm with qh3.
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
    }
    # map RFC 4514 short attribute names to the long names used by the
    # stdlib ssl module; unknown attributes are silently skipped.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        # each RDN is a 1-tuple of (name, value) pairs, like ssl.getpeercert().
        issuer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        issuer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    return issuer_info
@typing.overload
def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...
@typing.overload
def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...
def getpeercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any]:
    """Return the peer (leaf) certificate, mimicking ``ssl.SSLSocket.getpeercert``.

    Either the DER bytes (``binary_form=True``) or a dict carrying subject /
    issuer RDN 1-tuples, validity dates (``"%b %d %H:%M:%S %Y UTC"``),
    subjectAltName entries and OCSP / caIssuers / CRL endpoints. List values
    are converted to tuples and empty ones removed, so absent extensions do
    not appear as keys at all.

    :raises ValueError: if called before the TLS handshake completed.
    """
    x509_certificate = self._quic.get_peercert()
    if x509_certificate is None:
        raise ValueError("TLS handshake has not been done yet")
    if binary_form:
        return x509_certificate.public_bytes()
    peer_info = {
        # NOTE(review): assumes a 0-based backend version, +1 matches the
        # ssl module convention — confirm with qh3.
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "subjectAltName": [],
        "OCSP": [],
        "caIssuers": [],
        "crlDistributionPoints": [],
    }
    # map RFC 4514 short attribute names to the long names used by the
    # stdlib ssl module; unknown attributes are silently skipped.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        # each RDN is a 1-tuple of (name, value) pairs, like ssl.getpeercert().
        peer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        peer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    for alt_name in x509_certificate.get_subject_alt_names():
        decoded_alt_name = alt_name.decode()
        # entries look like "DNS(example.org)" or "IP(hex:bytes:...)".
        in_parenthesis = decoded_alt_name[
            decoded_alt_name.index("(") + 1 : decoded_alt_name.index(")")
        ]
        if decoded_alt_name.startswith("DNS"):
            peer_info["subjectAltName"].append(("DNS", in_parenthesis))  # type: ignore[attr-defined]
        else:
            from ....resolver.utils import inet4_ntoa, inet6_ntoa

            # four colon-separated hex bytes ("xx:xx:xx:xx") are exactly
            # 11 chars -> IPv4; anything else is treated as IPv6.
            # NOTE(review): relies on the backend's exact rendering — confirm.
            if len(in_parenthesis) == 11:
                ip_address_decoded = inet4_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            else:
                ip_address_decoded = inet6_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            peer_info["subjectAltName"].append(("IP Address", ip_address_decoded))  # type: ignore[attr-defined]
    # (fix) dropped the redundant re-initializations of "OCSP", "caIssuers"
    # and "crlDistributionPoints": they are already empty lists in the dict
    # literal above and nothing mutates them in between.
    for endpoint in x509_certificate.get_ocsp_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["OCSP"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    for endpoint in x509_certificate.get_issuer_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["caIssuers"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    for endpoint in x509_certificate.get_crl_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["crlDistributionPoints"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    # freeze list values into tuples and drop the ones that stayed empty,
    # matching the stdlib's behavior of omitting absent fields.
    pop_keys = []
    for k in peer_info:
        if isinstance(peer_info[k], list):
            peer_info[k] = tuple(peer_info[k])  # type: ignore[arg-type]
            if not peer_info[k]:
                pop_keys.append(k)
    for k in pop_keys:
        peer_info.pop(k)
    return peer_info
def cipher(self) -> str | None:
    """Return the negotiated TLS cipher suite name, prefixed with ``TLS_``.

    :raises ValueError: if the handshake has not completed yet.
    """
    cipher_suite = self._quic.get_cipher()
    if cipher_suite is None:
        raise ValueError("TLS handshake has not been done yet")
    # presumably an enum-like object: its .name lacks the TLS_ prefix — the
    # declared `str | None` return never actually yields None (we raise).
    return f"TLS_{cipher_suite.name}"
def reshelve(self, *events: Event) -> None:
    """Put *events* back at the front of the queue, keeping their order."""
    # walk backwards so the first given event ends up first in the queue.
    for position in range(len(events) - 1, -1, -1):
        self._events.appendleft(events[position])
def ping(self) -> None:
    """Send a QUIC PING frame (random uid) to probe connection liveness."""
    self._quic.send_ping(randint(0, 65535))
def max_frame_size(self) -> int:
    """Largest DATA frame payload accepted on this connection, when known."""
    if self._max_frame_size is None:
        # the backend never advertised a limit.
        raise NotImplementedError
    return self._max_frame_size

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
import typing
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
from ._ctypes import load_cert_chain as _ctypes_load_cert_chain
from ._shm import load_cert_chain as _shm_load_cert_chain
# Ordered candidate strategies for loading an in-memory client certificate.
# Each callable takes (ctx, certdata, keydata, password) and raises
# io.UnsupportedOperation when the strategy cannot work on this platform.
SUPPORTED_METHODS: list[
    typing.Callable[
        [
            ssl.SSLContext,
            bytes | str,
            bytes | str,
            bytes | str | typing.Callable[[], str | bytes] | None,
        ],
        None,
    ]
] = [
    _ctypes_load_cert_chain,  # preferred: poke OpenSSL directly through ctypes
    _shm_load_cert_chain,  # fallback: POSIX shared-memory pseudo file
]
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround for CPython's inability to initialize an mTLS context
    without on-disk files. Tries each strategy in SUPPORTED_METHODS until
    one succeeds.
    :raise UnsupportedOperation: If every strategy fails; the first failure
        encountered is re-raised.
    """
    first_failure: UnsupportedOperation | None = None
    for strategy in SUPPORTED_METHODS:
        try:
            strategy(
                ctx,
                certdata,
                keydata,
                password,
            )
        except UnsupportedOperation as exc:
            if first_failure is None:
                first_failure = exc
        else:
            return
    if first_failure is not None:
        raise first_failure
    # only reachable if SUPPORTED_METHODS is empty.
    raise UnsupportedOperation("unable to initialize mTLS using in-memory cert and key")

__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,376 @@
from __future__ import annotations
import ctypes
import os
import sys
import typing
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
class _OpenSSL:
    """Access hazardous material from CPython OpenSSL (or compatible SSL) implementation."""

    def __init__(self) -> None:
        """Locate the libssl/libcrypto shared objects backing CPython's ssl
        module and bind the handful of C symbols we need via ctypes.

        :raises UnsupportedOperation: on non-CPython interpreters, when the
            private ``ssl._ssl`` module is missing, when the DLLs cannot be
            located (Windows), or when a required symbol is absent.
        """
        import platform

        if platform.python_implementation() != "CPython":
            raise UnsupportedOperation("Only CPython is supported")

        import ssl

        self._name = ssl.OPENSSL_VERSION
        self.ssl = ssl

        # bug seen in Windows + CPython < 3.11
        # where CPython official API for options
        # cast OpenSSL get_options to SIGNED long
        # where we want UNSIGNED long.
        _ssl_options_signed_long_bug = False

        if not hasattr(ssl, "_ssl"):
            raise UnsupportedOperation(
                "Unsupported interpreter due to missing private ssl module"
            )

        if platform.system() == "Windows":
            # possible search locations
            candidates = {
                os.path.dirname(sys.executable),
                os.path.join(sys.prefix, "DLLs"),
                sys.prefix,
            }

            if hasattr(ssl._ssl, "__file__"):
                candidates.add(os.path.dirname(ssl._ssl.__file__))

            _ssl_options_signed_long_bug = sys.version_info < (3, 11)

            ssl_potential_match = None
            crypto_potential_match = None

            # scan each candidate directory for libssl*/libcrypto* DLLs.
            for d in candidates:
                if not os.path.exists(d):
                    continue

                for filename in os.listdir(d):
                    if ssl_potential_match is None:
                        if filename.startswith("libssl") and filename.endswith(".dll"):
                            ssl_potential_match = os.path.join(d, filename)
                    if crypto_potential_match is None:
                        if filename.startswith("libcrypto") and filename.endswith(
                            ".dll"
                        ):
                            crypto_potential_match = os.path.join(d, filename)

                if crypto_potential_match and ssl_potential_match:
                    break

            if not ssl_potential_match or not crypto_potential_match:
                raise UnsupportedOperation(
                    "Could not locate OpenSSL DLLs next to Python; "
                    "check your /DLLs folder or your PATH."
                )

            self._ssl = ctypes.CDLL(ssl_potential_match)
            self._crypto = ctypes.CDLL(crypto_potential_match)
        else:
            # that's the most common path
            # ssl built in module already loaded both crypto and ssl
            # symbols.
            if hasattr(ssl._ssl, "__file__"):
                self._ssl = ctypes.CDLL(ssl._ssl.__file__)
            else:
                # _ssl is statically linked into the interpreter
                # (e.g. python-build-standalone via uv). OpenSSL symbols
                # are in the main process image; ctypes.CDLL(None) exposes them.
                # see https://github.com/jawah/urllib3.future/issues/325 for more
                # details.
                self._ssl = ctypes.CDLL(None)
            self._crypto = self._ssl

        # we want to ensure a minimal set of symbols
        # are present. CPython should have at least:
        for required_symbol in [
            "SSL_CTX_use_certificate",
            "SSL_CTX_check_private_key",
            "SSL_CTX_use_PrivateKey",
        ]:
            if not hasattr(self._ssl, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libssl) {required_symbol} is not present."
                )

        for required_symbol in [
            "BIO_free",
            "BIO_new_mem_buf",
            "PEM_read_bio_X509",
            "PEM_read_bio_PrivateKey",
            "ERR_get_error",
            "ERR_error_string",
        ]:
            if not hasattr(self._crypto, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libcrypto) {required_symbol} is not present."
                )

        # Declare argtypes/restype for each bound symbol so ctypes marshals
        # pointers correctly on every ABI (64-bit pointers must not be
        # truncated to the default c_int).
        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_certificate = self._ssl.SSL_CTX_use_certificate
        self.SSL_CTX_use_certificate.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_certificate.restype = ctypes.c_int

        self.SSL_CTX_check_private_key = self._ssl.SSL_CTX_check_private_key
        self.SSL_CTX_check_private_key.argtypes = [ctypes.c_void_p]
        self.SSL_CTX_check_private_key.restype = ctypes.c_int

        # https://docs.openssl.org/3.0/man3/BIO_new/
        self.BIO_free = self._crypto.BIO_free
        self.BIO_free.argtypes = [ctypes.c_void_p]
        self.BIO_free.restype = None

        self.BIO_new_mem_buf = self._crypto.BIO_new_mem_buf
        self.BIO_new_mem_buf.argtypes = [ctypes.c_void_p, ctypes.c_int]
        self.BIO_new_mem_buf.restype = ctypes.c_void_p

        # https://docs.openssl.org/3.0/man3/PEM_read_bio_PrivateKey/
        self.PEM_read_bio_X509 = self._crypto.PEM_read_bio_X509
        self.PEM_read_bio_X509.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_X509.restype = ctypes.c_void_p

        self.PEM_read_bio_PrivateKey = self._crypto.PEM_read_bio_PrivateKey
        self.PEM_read_bio_PrivateKey.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_PrivateKey.restype = ctypes.c_void_p

        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_PrivateKey = self._ssl.SSL_CTX_use_PrivateKey
        self.SSL_CTX_use_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_PrivateKey.restype = ctypes.c_int

        self.ERR_get_error = self._crypto.ERR_get_error
        self.ERR_get_error.argtypes = []
        self.ERR_get_error.restype = ctypes.c_ulong

        self.ERR_error_string = self._crypto.ERR_error_string
        self.ERR_error_string.argtypes = [ctypes.c_ulong, ctypes.c_char_p]
        self.ERR_error_string.restype = ctypes.c_char_p

        if hasattr(self._ssl, "SSL_CTX_get_options"):
            self.SSL_CTX_get_options = self._ssl.SSL_CTX_get_options
            self.SSL_CTX_get_options.argtypes = [ctypes.c_void_p]
            self.SSL_CTX_get_options.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )  # OpenSSL's options are long
        elif hasattr(self._ssl, "SSL_CTX_ctrl"):
            # some old build inline SSL_CTX_get_options (mere C define)
            # define SSL_CTX_get_options(ctx) SSL_CTX_ctrl((ctx),SSL_CTRL_OPTIONS,0,NULL)
            # define SSL_CTRL_OPTIONS 32
            self.SSL_CTX_ctrl = self._ssl.SSL_CTX_ctrl
            self.SSL_CTX_ctrl.argtypes = [
                ctypes.c_void_p,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_void_p,
            ]
            self.SSL_CTX_ctrl.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )
            self.SSL_CTX_get_options = lambda ctx: self.SSL_CTX_ctrl(  # type: ignore[assignment]
                ctx, 32, 0, None
            )
        else:
            raise UnsupportedOperation()

    def pull_error(self) -> typing.NoReturn:
        """Pop the most recent OpenSSL error and raise it as ssl.SSLError."""
        raise self.ssl.SSLError(
            self.ERR_error_string(
                self.ERR_get_error(), ctypes.create_string_buffer(256)
            ).decode()
        )
# True on CPython free-threaded (no-GIL) builds: the object header grows,
# shifting the SSL_CTX pointer offset computed further down this module.
_IS_GIL_DISABLED = hasattr(sys, "_is_gil_enabled") and sys._is_gil_enabled() is False
_IS_LINUX = sys.platform == "linux"
# number of extra pointer-sized header slots on free-threaded builds
# (observed: 1 on Linux, 2 elsewhere).
_FT_HEAD_ADDITIONAL_OFFSET = 1 if _IS_LINUX else 2
_head_extra_fields = []
if sys.flags.debug:
    # In debug builds (_POSIX_C_SOURCE or Py_DEBUG is defined), PyObject_HEAD
    # is preceded by _PyObject_HEAD_EXTRA, which typically consists of
    # two pointers (_ob_next, _ob_prev).
    _head_extra_fields = [("_ob_next", ctypes.c_void_p), ("_ob_prev", ctypes.c_void_p)]
# Define the PySSLContext C structure using ctypes.
# This definition assumes that 'SSL_CTX *ctx' is the first member
# immediately following PyObject_HEAD. This has been observed to be
# the case in various CPython versions (e.g., 3.7 through 3.14 so far).
#
# CPython's Modules/_ssl.c (simplified):
# typedef struct {
#     PyObject_HEAD // Expands to _PyObject_HEAD_EXTRA (if debug) + ob_refcnt + ob_type
#     SSL_CTX *ctx;
#     // ... other members ...
# } PySSLContextObject;
#
class PySSLContextStruct(ctypes.Structure):
    """ctypes mirror of CPython's PySSLContextObject header, just deep
    enough to reach the embedded ``SSL_CTX *`` pointer."""

    _fields_ = (
        _head_extra_fields  # type: ignore[assignment]
        + [
            ("ob_refcnt", ctypes.c_ssize_t),  # Py_ssize_t ob_refcnt;
            ("ob_type", ctypes.c_void_p),  # PyTypeObject *ob_type;
        ]
        + (
            # free-threaded builds insert extra header words before ctx.
            [(f"_ob_ft{i}", ctypes.c_void_p) for i in range(_FT_HEAD_ADDITIONAL_OFFSET)]
            if _IS_GIL_DISABLED
            else []
        )
        + [
            ("ssl_ctx", ctypes.c_void_p),  # SSL_CTX *ctx; (this is the pointer we want)
            # If there were other C members between ob_type and ssl_ctx,
            # they would need to be defined here with their correct types and padding.
        ]
    )
def _split_client_cert(data: bytes) -> list[bytes]:
line_ending = b"\n" if b"-----\r\n" not in data else b"\r\n"
boundary = b"-----END CERTIFICATE-----" + line_ending
certificates = []
for chunk in data.split(boundary):
if chunk:
start_marker = chunk.find(b"-----BEGIN CERTIFICATE-----" + line_ending)
if start_marker == -1:
break
pem_reconstructed = b"".join([chunk[start_marker:], boundary])
certificates.append(pem_reconstructed)
return certificates
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.

    Injects the leaf certificate and private key straight into the
    SSL_CTX embedded in *ctx* through ctypes, bypassing the file-based
    ``SSLContext.load_cert_chain``. Intermediates (if any) are fed to the
    CA store so OpenSSL can rebuild the chain.

    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    lib = _OpenSSL()

    # Get the memory address of the Python ssl.SSLContext object.
    # id() returns the address of the PyObject.
    addr = id(ctx)

    # Cast this memory address to a pointer to our defined PySSLContextStruct.
    ptr_to_pysslcontext_struct = ctypes.cast(addr, ctypes.POINTER(PySSLContextStruct))

    # Access the 'ssl_ctx' field from the structure. This field holds the
    # actual SSL_CTX* C pointer value.
    ssl_ctx_address = ptr_to_pysslcontext_struct.contents.ssl_ctx

    # We want to ensure we got the right pointer address
    # the safest way to achieve that is to retrieve options
    # and compare it with the official ctx property.
    if lib.SSL_CTX_get_options is not None:
        bypass_options = lib.SSL_CTX_get_options(ssl_ctx_address)
        expected_options = int(ctx.options)

        if bypass_options != expected_options:
            raise UnsupportedOperation(
                f"CPython internal SSL_CTX changed! Cannot pursue safely. Expected = {expected_options:x} Actual = {bypass_options:x}"
            )

    # normalize inputs
    if isinstance(certdata, str):
        certdata = certdata.encode()
    if isinstance(keydata, str):
        keydata = keydata.encode()

    client_chain = _split_client_cert(certdata)
    # NOTE(review): assumes at least one PEM block is present; an empty
    # certdata would raise IndexError here — confirm callers validate input.
    leaf_certificate = client_chain[0]

    # Use a BIO to read the client certificate
    # only the leaf certificate is supported here.
    cert_bio = lib.BIO_new_mem_buf(leaf_certificate, len(leaf_certificate))

    if not cert_bio:
        raise MemoryError("Unable to allocate memory to load the client certificate")

    # Use a BIO to load the key in-memory
    key_bio = lib.BIO_new_mem_buf(keydata, len(keydata))

    if not key_bio:
        raise MemoryError("Unable to allocate memory to load the client key")

    # prepare the password: OpenSSL expects bytes (or NULL) as userdata.
    if callable(password):
        password = password()
    if isinstance(password, str):
        password = password.encode()

    assert password is None or isinstance(password, bytes)

    # the allocated X509 obj MUST NOT be freed by ourselves
    # OpenSSL internals will free it once not needed.
    cert = lib.PEM_read_bio_X509(cert_bio, None, None, None)

    # we do own the BIO, once the X509 leaf is instantiated, no need
    # to keep it afterward.
    lib.BIO_free(cert_bio)

    if not cert:
        lib.pull_error()

    pkey = lib.PEM_read_bio_PrivateKey(key_bio, None, None, password)

    lib.BIO_free(key_bio)

    if not pkey:
        lib.pull_error()

    # hand cert and key to the context, then verify they actually match.
    if lib.SSL_CTX_use_certificate(ssl_ctx_address, cert) != 1:
        lib.pull_error()

    if lib.SSL_CTX_use_PrivateKey(ssl_ctx_address, pkey) != 1:
        lib.pull_error()

    if lib.SSL_CTX_check_private_key(ssl_ctx_address) != 1:
        lib.pull_error()

    # Unfortunately, most of the time
    # SSL_CTX_add_extra_chain_cert is unavailable
    # in the final CPython build.
    # According to OpenSSL latest docs: "The engine
    # will attempt to build the required chain for the CA store"
    # It's not going to be used as a trust anchor! (i.e. not self-signed)
    # "If no chain is specified, the library will try to complete the
    # chain from the available CA certificates in the trusted
    # CA storage, see SSL_CTX_load_verify_locations(3)."
    # see: https://docs.openssl.org/master/man3/SSL_CTX_add_extra_chain_cert/#notes
    if len(client_chain) > 1:
        ctx.load_verify_locations(cadata=(b"\n".join(client_chain[1:])).decode())

__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import os
import secrets
import stat
import sys
import typing
import warnings
from hashlib import sha256
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: str | bytes,
    keydata: str | bytes | None = None,
    password: typing.Callable[[], str | bytes] | str | bytes | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
    Only supported on Linux, FreeBSD, and OpenBSD.

    Writes cert (and key) into an anonymous / shared-memory file, points
    ``SSLContext.load_cert_chain`` at its /proc (or /dev/shm) path, then
    cleans the remnants up so nothing persists on disk.

    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    if (
        sys.platform != "linux"
        and sys.platform.startswith("freebsd") is False
        and sys.platform.startswith("openbsd") is False
    ):
        raise UnsupportedOperation(
            f"Unable to provide support for in-memory client certificate: Unsupported platform {sys.platform}"
        )

    # random, collision-proof name for the pseudo file.
    unique_name: str = f"{sha256(secrets.token_bytes(32)).hexdigest()}.pem"

    if isinstance(certdata, bytes):
        certdata = certdata.decode("ascii")

    if keydata is not None:
        if isinstance(keydata, bytes):
            keydata = keydata.decode("ascii")

    if hasattr(os, "memfd_create"):
        # Linux 3.17+: a true anonymous in-memory file descriptor.
        fd = os.memfd_create(unique_name, os.MFD_CLOEXEC)
    else:
        # this branch patch is for CPython <3.8 and PyPy 3.7+
        from ctypes import c_int, c_ushort, cdll, create_string_buffer, get_errno, util

        loc = util.find_library("rt") or util.find_library("c")

        if not loc:
            raise UnsupportedOperation(
                "Unable to provide support for in-memory client certificate: libc or librt not found."
            )

        lib = cdll.LoadLibrary(loc)

        _shm_open = lib.shm_open
        # _shm_unlink = lib.shm_unlink

        buf_name = create_string_buffer(unique_name.encode())

        try:
            fd = _shm_open(
                buf_name,
                c_int(os.O_RDWR | os.O_CREAT),
                c_ushort(stat.S_IRUSR | stat.S_IWUSR),
            )
        except SystemError as e:
            raise UnsupportedOperation(
                f"Unable to provide support for in-memory client certificate: {e}"
            )

        if fd == -1:
            raise UnsupportedOperation(
                f"Unable to provide support for in-memory client certificate: {os.strerror(get_errno())}"
            )

    # Linux 3.17+
    path = f"/proc/self/fd/{fd}"
    # Alt-path
    shm_path = f"/dev/shm/{unique_name}"

    if os.path.exists(path) is False:
        if os.path.exists(shm_path):
            path = shm_path
        else:
            # wrap+close releases the raw fd before bailing out.
            os.fdopen(fd).close()
            raise UnsupportedOperation(
                "Unable to provide support for in-memory client certificate: no virtual patch available?"
            )

    # owner-only permissions: the key material must not be world-readable.
    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)

    with open(path, "w") as fp:
        fp.write(certdata)

        if keydata:
            # NOTE(review): keydata is appended directly after certdata with
            # no separator — assumes certdata ends with a newline; confirm.
            fp.write(keydata)

        path = fp.name

    ctx.load_cert_chain(path, password=password)

    # we shall start cleaning remnants
    os.fdopen(fd).close()

    if os.path.exists(shm_path):
        os.unlink(shm_path)

    if os.path.exists(path) or os.path.exists(shm_path):
        warnings.warn(
            "In-memory client certificate: The kernel leaked a file descriptor outside of its expected lifetime.",
            ResourceWarning,
        )

__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import warnings
# Emitted at import time: this module is a stub kept only so that legacy
# code doing `import urllib3.contrib.pyopenssl` keeps importing cleanly.
warnings.warn(
    (
        "'urllib3.contrib.pyopenssl' module has been removed in urllib3.future due to incompatibilities "
        "with our QUIC integration. While the import proceed without error for your convenience, it is rendered "
        "completely ineffective. Were you looking for in-memory client certificate? "
        "See https://urllib3future.readthedocs.io/en/latest/advanced-usage.html#in-memory-client-mtls-certificate"
    ),
    category=DeprecationWarning,
    stacklevel=2,
)

# Preserve the historical hard dependency: importing this module has always
# required pyOpenSSL, so failing fast here mimics the original behavior.
import OpenSSL.SSL  # type: ignore # noqa

__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
def inject_into_urllib3() -> None:
    """No-op retained solely for backward compatibility with classic urllib3."""
    return None
def extract_from_urllib3() -> None:
    """No-op retained solely for backward compatibility with classic urllib3."""
    return None

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from .factories import ResolverDescription, ResolverFactory
from .protocols import BaseResolver, ManyResolver, ProtocolResolver
# Public API of the synchronous resolver package.
__all__ = (
    "ResolverFactory",
    "ProtocolResolver",
    "BaseResolver",
    "ManyResolver",
    "ResolverDescription",
)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from .factories import AsyncResolverDescription, AsyncResolverFactory
from .protocols import AsyncBaseResolver, AsyncManyResolver
# Public API of the asynchronous resolver package.
__all__ = (
    "AsyncResolverDescription",
    "AsyncResolverFactory",
    "AsyncBaseResolver",
    "AsyncManyResolver",
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
# Public API: the generic DoH resolver plus vendor presets.
__all__ = (
    "HTTPSResolver",
    "GoogleResolver",
    "CloudflareResolver",
    "AdGuardResolver",
    "OpenDNSResolver",
    "Quad9Resolver",
    "NextDNSResolver",
)

View File

@@ -0,0 +1,656 @@
from __future__ import annotations
import socket
import typing
from asyncio import as_completed
from base64 import b64encode
from ....._async.connectionpool import AsyncHTTPSConnectionPool
from ....._async.response import AsyncHTTPResponse
from ....._collections import HTTPHeaderDict
from .....backend import ConnectionInfo, HttpVersion
from .....util.url import parse_url
from ...protocols import (
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
from ..protocols import AsyncBaseResolver
class HTTPSResolver(AsyncBaseResolver):
    """
    Advanced DNS over HTTPS resolver.
    No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
    also implemented at Cloudflare.
    Support RFC 8484 without JSON. Disabled by default.
    """

    # implementation label and the DNS transport this resolver speaks.
    implementation = "urllib3"
    protocol = ProtocolResolver.DOH

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Build the resolver and its underlying async HTTPS pool.

        Recognized extra kwargs (extracted or transformed before the rest is
        forwarded to ``AsyncHTTPSConnectionPool``): ``path``, ``rfc8484``,
        ``source_address`` ("ip:port" string), ``proxy``, ``proxy_headers``,
        ``headers``, ``disabled_svn`` and ``on_post_connection``.
        """
        super().__init__(server, port or 443, *patterns, **kwargs)

        # Google's JSON endpoint by default; vendors may override via "path".
        self._path: str = "/resolve"

        if "path" in kwargs:
            if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
                self._path = kwargs["path"]
            kwargs.pop("path")

        # RFC 8484 binary wire format (application/dns-message) toggle.
        self._rfc8484: bool = False

        if "rfc8484" in kwargs:
            if kwargs["rfc8484"]:
                self._rfc8484 = True
            kwargs.pop("rfc8484")

        assert self._server is not None

        if "source_address" in kwargs:
            # expects a "bind_ip:bind_port" string; converted to the tuple
            # form the pool understands.
            if isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)

                if bind_ip and bind_port.isdigit():
                    kwargs["source_address"] = (
                        bind_ip,
                        int(bind_port),
                    )
                else:
                    raise ValueError("invalid source_address given in parameters")
            else:
                raise ValueError("invalid source_address given in parameters")

        if "proxy" in kwargs:
            kwargs["_proxy"] = parse_url(kwargs["proxy"])
            kwargs.pop("proxy")

        if "maxsize" not in kwargs:
            kwargs["maxsize"] = 10

        # headers are supplied as "Name: value" strings (or a list of them).
        if "proxy_headers" in kwargs and "_proxy" in kwargs:
            proxy_headers = HTTPHeaderDict()

            if not isinstance(kwargs["proxy_headers"], list):
                kwargs["proxy_headers"] = [kwargs["proxy_headers"]]

            for item in kwargs["proxy_headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")
                k, v = item.split(":", 1)
                proxy_headers.add(k, v)

            kwargs["_proxy_headers"] = proxy_headers

        if "headers" in kwargs:
            headers = HTTPHeaderDict()

            if not isinstance(kwargs["headers"], list):
                kwargs["headers"] = [kwargs["headers"]]

            for item in kwargs["headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")
                k, v = item.split(":", 1)
                headers.add(k, v)

            kwargs["headers"] = headers

        # translate textual protocol names into HttpVersion members;
        # unrecognized names are silently dropped.
        if "disabled_svn" in kwargs:
            if not isinstance(kwargs["disabled_svn"], list):
                kwargs["disabled_svn"] = [kwargs["disabled_svn"]]

            disabled_svn = set()

            for svn in kwargs["disabled_svn"]:
                svn = svn.lower()

                if svn == "h11":
                    disabled_svn.add(HttpVersion.h11)
                elif svn == "h2":
                    disabled_svn.add(HttpVersion.h2)
                elif svn == "h3":
                    disabled_svn.add(HttpVersion.h3)

            kwargs["disabled_svn"] = disabled_svn

        if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
            self._connection_callback: (
                typing.Callable[[ConnectionInfo], None] | None
            ) = kwargs["on_post_connection"]
            kwargs.pop("on_post_connection")
        else:
            self._connection_callback = None

        self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)

    async def close(self) -> None:  # type: ignore[override]
        """Shut the underlying connection pool down."""
        await self._pool.close()

    def is_available(self) -> bool:
        """True while the underlying pool has not been closed."""
        return self._pool.pool is not None

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DNS-over-HTTPS, mimicking ``socket.getaddrinfo``.

        A/AAAA queries are issued according to *family*; when
        ``quic_upgrade_via_dns_rr`` is set, an HTTPS (type 65) RR query is
        added and, if the record advertises h3, mirrored UDP entries are
        prepended so callers may attempt QUIC first.

        :raises socket.gaierror: on invalid input, non-2xx server replies,
            or a non-zero DNS status.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # literal IP addresses short-circuit: no DNS query is issued.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    # NOTE(review): proto is hard-coded to 6 (TCP) here but
                    # 17 (UDP) in the IPv6 literal branch below — looks
                    # inconsistent with the requested socket kind; confirm.
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        promises = []
        remote_preemptive_quic_rr = False

        # a datagram socket means the caller already speaks QUIC: no point
        # in asking for the HTTPS RR upgrade hint.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                # RFC 8484: base64url wire-format query, padding stripped.
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if quic_upgrade_via_dns_rr:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        tasks = []
        responses = []

        for promise in promises:
            # already resolved
            if isinstance(promise, AsyncHTTPResponse):
                responses.append(promise)
                continue
            tasks.append(self._pool.get_response(promise=promise))

        if tasks:
            # gather remaining responses as they complete, in any order.
            for waiting_promise_coro in as_completed(tasks):
                responses.append(await waiting_promise_coro)  # type: ignore[arg-type]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )

            if not self._rfc8484:
                # Google/Cloudflare JSON schema: Status, Question, Answer...
                payload = await response.json()

                assert "Status" in payload and isinstance(payload["Status"], int)

                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )

                    if isinstance(msg, list):
                        msg = ", ".join(msg)

                    raise socket.gaierror(msg)

                assert "Question" in payload and isinstance(payload["Question"], list)

                # no Answer section means an empty (but successful) response.
                if "Answer" not in payload:
                    continue

                assert isinstance(payload["Answer"], list)

                for answer in payload["Answer"]:
                    # only A (1), AAAA (28) and HTTPS (65) records matter here.
                    if answer["type"] not in [1, 28, 65]:
                        continue

                    assert "data" in answer
                    assert isinstance(answer["data"], str)

                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]

                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            # strip the "\# <len>" prefix, keep the hex body.
                            rr = "".join(rr[2:].split(" ")[2:])

                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""

                            if not raw_record:
                                continue

                            https_record = parse_https_rdata(raw_record)

                            if "h3" not in https_record["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True
                        else:
                            # presentation format: parse "key=value" tokens.
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )

                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True

                            # ipv4hint/ipv6hint addresses go straight into
                            # the results as UDP (QUIC-ready) entries.
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )
                        continue

                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )
                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # RFC 8484 wire format: decode the raw DNS message.
                dns_resp = DomainNameServerReturn(await response.data)

                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue

                    assert not isinstance(record[-1], dict)

                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )
                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )

        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        if remote_preemptive_quic_rr:
            # the HTTPS RR advertised h3: mirror TCP entries as UDP ones so a
            # QUIC attempt can happen first — but only when every existing
            # result is TCP (a UDP entry means hints were already added).
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        # descending sort on (family + socktype) puts IPv6/UDP entries first.
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Opting into the RFC 8484 wire format targets Google's /dns-query endpoint.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to Cloudflare (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Cloudflare serves DoH under /dns-query.
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # AdGuard serves RFC 8484 DoH under /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # OpenDNS serves RFC 8484 DoH under /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Quad9 serves RFC 8484 DoH under /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(HTTPSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut pinned to NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations

# DNS-over-QUIC support is optional: ._qh3 depends on the third-party qh3
# package and refuses to import on FIPS-restricted builds. When unavailable,
# expose the names as None so callers can feature-detect instead of crashing
# at import time.
try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    QUICResolver = None  # type: ignore
    AdGuardResolver = None  # type: ignore
    NextDNSResolver = None  # type: ignore

__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)

View File

@@ -0,0 +1,557 @@
from __future__ import annotations
import asyncio
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..dou import PlainResolver
from ..system import SystemResolver
# Import gate: on FIPS-constrained builds this module refuses to load at all,
# so importers that treat DoQ as optional can simply catch ImportError.
if IS_FIPS:
    raise ImportError(
        "DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
    )
class QUICResolver(PlainResolver):
    """DNS resolver over a dedicated QUIC connection (DoQ, RFC 9250).

    Backed by the qh3 package. A single QUIC connection toward the configured
    server is established lazily upon the first resolution and then shared by
    all concurrent callers; each DNS query travels on its own QUIC stream.
    """

    protocol = ProtocolResolver.DOQ
    implementation = "qh3"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        super().__init__(server, port or 853, *patterns, **kwargs)

        # qh3 load_default_certs seems off. need to investigate.
        # Meanwhile: harvest the system trust store ourselves and hand a PEM
        # bundle to qh3 through ca_cert_data.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

            try:
                ctx.load_default_certs()

                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))

                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                del kwargs["ca_cert_data"]

        # Without a single CA at hand, certificate verification cannot
        # succeed: fail early unless the caller explicitly disabled checks.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )

        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )

        # Optional mutual-TLS client certificate, by file path or in-memory.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )

        self._quic = QuicConnection(configuration=configuration)

        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._handshake_event: asyncio.Event = asyncio.Event()

        self._terminated: bool = False
        self._should_disconnect: bool = False

        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True

        # Responses read by one task on behalf of another, and queries that
        # were sent but not answered yet.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

    async def close(self) -> None:  # type: ignore[override]
        """Gracefully close the QUIC connection and the underlying socket."""
        if (
            not self._terminated
            and self._socket is not None
            and not self._socket.should_connect()
        ):
            self._quic.close()

            # Flush the CONNECTION_CLOSE frame(s) before tearing down.
            while True:
                datagrams = self._quic.datagrams_to_send(monotonic())

                if not datagrams:
                    break

                for datagram in datagrams:
                    data, addr = datagram
                    await self._socket.sendall(data)

            self._socket.close()
            await self._socket.wait_for_close()

            self._terminated = True

        if self._socket is None or self._socket.should_connect():
            self._terminated = True

    def is_available(self) -> bool:
        """Return True while this resolver can still serve queries."""
        if self._terminated:
            return False
        if self._socket is None or self._socket.should_connect():
            return True
        # Let qh3 process its timers; it may conclude the connection is gone
        # (e.g. idle timeout), which surfaces through its private close event.
        self._quic.handle_timer(monotonic())
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* through the shared DNS-over-QUIC connection.

        Mimics :func:`socket.getaddrinfo`. When *quic_upgrade_via_dns_rr* is
        true and a SOCK_STREAM resolution is requested, an extra HTTPS-type
        query is issued; an ``h3`` ALPN hint in the answer yields additional
        SOCK_DGRAM entries, sorted first, to invite an HTTP/3 attempt.

        :raises socket.gaierror: on invalid input, DNS-level errors, or loss
            of the underlying connection.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass the network entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    # NOTE(review): proto is fixed (6/17) here regardless of
                    # *type*, unlike the computed value below — confirm intended.
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        # First caller performs connect + QUIC handshake; concurrent callers
        # wait on the handshake event instead.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()

            assert self.server is not None

            self._quic.connect((self._server, self._port), monotonic())

            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=None,
                socket_kind=self._socket_type,
            )

            await self.__exchange_until(HandshakeCompleted, receive_first=False)

            self._handshake_event.set()
        else:
            await self._handshake_event.wait()

        assert self._socket is not None

        remote_preemptive_quic_rr = False

        # Already resolving for a datagram socket: no HTTPS-RR upgrade probe.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        open_streams = []

        # One query per QUIC stream, end-of-stream signalled immediately.
        for q in queries:
            payload = bytes(q)

            self._pending.append(q)

            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)

            stream_id = self._quic.get_next_available_stream_id()
            self._quic.send_stream_data(stream_id, payload, True)

            open_streams.append(stream_id)

        for dg in self._quic.datagrams_to_send(monotonic()):
            await self._socket.sendall(dg[0])

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            await self._read_semaphore.acquire()

            # Another task may have already read answers that belong to us.
            if self._unconsumed:
                dns_resp = None

                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break

                if dns_resp:
                    self._unconsumed.remove(dns_resp)
                    self._pending.remove(query)
                    self._read_semaphore.release()
                    continue

            try:
                events: list[StreamDataReceived] = await self.__exchange_until(  # type: ignore[assignment]
                    StreamDataReceived,
                    receive_first=True,
                    event_type_collectable=(StreamDataReceived,),
                    respect_end_stream_signal=False,
                )

                payload = b"".join([e.data for e in events])

                # Keep reading until the RFC 1035 length prefix is satisfied.
                while rfc1035_should_read(payload):
                    events.extend(
                        await self.__exchange_until(  # type: ignore[arg-type]
                            StreamDataReceived,
                            receive_first=True,
                            event_type_collectable=(StreamDataReceived,),
                            respect_end_stream_signal=False,
                        )
                    )
                    payload = b"".join([e.data for e in events])
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            if not payload:
                continue

            #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
            fragments = rfc1035_unpack(payload)

            for fragment in fragments:
                dns_resp = DomainNameServerReturn(fragment)

                if any(dns_resp.id == _.id for _ in queries):
                    responses.append(dns_resp)

                    # Locate the matching pending query, if still present.
                    # Bug fix: the previous for/break reused the loop variable,
                    # which stays bound to an unrelated query when no match is
                    # found, and the wrong entry then got removed.
                    query_tbr: DomainNameServerQuery | None = None

                    for pending_query in self._pending:
                        if pending_query.id == dns_resp.id:
                            query_tbr = pending_query
                            break

                    if query_tbr:
                        self._pending.remove(query_tbr)
                else:
                    self._unconsumed.append(dns_resp)

        # The remote asked us to stop sending: honor it by tearing down.
        if self._should_disconnect:
            await self.close()
            self._should_disconnect = False
            self._terminated = True

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )

                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        # Promote SOCK_DGRAM variants (HTTP/3 candidates) ahead of the TCP
        # results, but only when every resolved entry was SOCK_STREAM.
        quic_results = []

        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)

    async def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """Pump the QUIC engine until *event_type* is observed.

        Optionally sends qh3's queued datagrams first, then reads from the
        socket, feeding qh3 and collecting events of *event_type_collectable*
        (or all of them when unset). Returns on the first *event_type*,
        optionally waiting for its end-of-stream flag.

        :raises SSLError: on peer certificate verification failure (298).
        :raises socket.gaierror: on connection termination or stream reset.
        """
        assert self._socket is not None

        while True:
            if receive_first is False:
                now = monotonic()

                while True:
                    datagrams = self._quic.datagrams_to_send(now)

                    if not datagrams:
                        break

                    for datagram in datagrams:
                        data, addr = datagram
                        await self._socket.sendall(data)

            events = []

            while True:
                if not self._quic._events:
                    data_in = await self._socket.recv(1500)

                    if not data_in:
                        break

                    now = monotonic()

                    # recv may hand back a list of GRO segments.
                    if not isinstance(data_in, list):
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )
                    else:
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )

                    # ACKs (and co) may be due immediately after ingestion.
                    while True:
                        datagrams = self._quic.datagrams_to_send(now)

                        if not datagrams:
                            break

                        for datagram in datagrams:
                            data, addr = datagram
                            await self._socket.sendall(data)

                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        self._should_disconnect = True
                        continue

                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)

                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events

            return events
class AdGuardResolver(QUICResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoQ shortcut pinned to AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(QUICResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoQ shortcut pinned to NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations

# Public surface of the DNS-over-TLS (DoT) resolver family: the generic
# TLSResolver plus vendor-pinned shortcuts, all implemented in ._ssl.
from ._ssl import (
    AdGuardResolver,
    CloudflareResolver,
    GoogleResolver,
    OpenDNSResolver,
    Quad9Resolver,
    TLSResolver,
)

__all__ = (
    "TLSResolver",
    "GoogleResolver",
    "CloudflareResolver",
    "AdGuardResolver",
    "Quad9Resolver",
    "OpenDNSResolver",
)

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import socket
import typing
from .....util._async.ssl_ import ssl_wrap_socket
from .....util.ssl_ import resolve_cert_reqs
from ...protocols import ProtocolResolver
from ..dou import PlainResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # DoT runs over TCP: pick the stream kind before PlainResolver
        # inspects it in its own constructor.
        self._socket_type = socket.SOCK_STREAM
        super().__init__(server, port or 853, *patterns, **kwargs)

        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host*, lazily establishing the TLS channel first.

        The first caller opens the TCP connection and wraps it in TLS;
        concurrent callers fall through to PlainResolver, which awaits the
        connect-finalized event before using the socket.
        """
        if self._socket is None and self._connect_attempt.is_set() is False:
            assert self.server is not None
            self._connect_attempt.set()

            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )

            options = self._kwargs

            # cert_reqs must go through resolve_cert_reqs only when the
            # caller actually provided it.
            wanted_cert_reqs = (
                resolve_cert_reqs(options["cert_reqs"])
                if "cert_reqs" in options
                else None
            )

            self._socket = await ssl_wrap_socket(
                self._socket,
                server_hostname=options.get("server_hostname", self.server),
                keyfile=options.get("key_file"),
                certfile=options.get("cert_file"),
                cert_reqs=wanted_cert_reqs,
                ca_certs=options.get("ca_certs"),
                ssl_version=options.get("ssl_version"),
                ciphers=options.get("ciphers"),
                ca_cert_dir=options.get("ca_cert_dir"),
                key_password=options.get("key_password"),
                ca_cert_data=options.get("ca_cert_data"),
                certdata=options.get("cert_data"),
                keydata=options.get("key_data"),
            )

            self._connect_finalized.set()

        return await super().getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
class GoogleResolver(TLSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoT shortcut pinned to Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(TLSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoT shortcut pinned to Cloudflare (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(TLSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoT shortcut pinned to AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(TLSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoT shortcut pinned to Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(TLSResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoT shortcut pinned to Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations

# Public surface of the plain DNS-over-UDP resolver family: the generic
# PlainResolver plus vendor-pinned shortcuts, all implemented in ._socket.
from ._socket import (
    AdGuardResolver,
    CloudflareResolver,
    GoogleResolver,
    PlainResolver,
    Quad9Resolver,
)

__all__ = (
    "PlainResolver",
    "CloudflareResolver",
    "GoogleResolver",
    "Quad9Resolver",
    "AdGuardResolver",
)

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
import asyncio
import socket
import typing
from collections import deque
from ....ssa import AsyncSocket
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..protocols import AsyncBaseResolver
from ..system import SystemResolver
class PlainResolver(AsyncBaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        super().__init__(server, port, *patterns, **kwargs)

        self._socket: AsyncSocket | None = None

        # Subclasses (e.g. the DoT resolver) may pick SOCK_STREAM before
        # delegating here; default to UDP otherwise.
        if not hasattr(self, "_socket_type"):
            self._socket_type = socket.SOCK_DGRAM

        # Accept "ip" or "ip:port" as the local bind address.
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            if ":" in kwargs["source_address"]:
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
                self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
            else:
                self._source_address = (kwargs["source_address"], 0)
        else:
            self._source_address = None

        if "timeout" in kwargs and isinstance(
            kwargs["timeout"],
            (
                float,
                int,
            ),
        ):
            self._timeout: float | int | None = kwargs["timeout"]
        else:
            self._timeout = None

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # Responses read by one task on behalf of another, and queries that
        # were sent but not answered yet.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._connect_finalized: asyncio.Event = asyncio.Event()

        self._terminated: bool = False

    async def close(self) -> None:  # type: ignore[override]
        """Close the underlying socket (if any) and mark this resolver dead."""
        if not self._terminated:
            # _lock is presumably provided by AsyncBaseResolver — TODO confirm.
            with self._lock:
                if self._socket is not None:
                    self._socket.close()
                    await self._socket.wait_for_close()
                self._terminated = True

    def is_available(self) -> bool:
        """Return True while this resolver can still serve queries."""
        return not self._terminated

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host*, mimicking :func:`socket.getaddrinfo`.

        When *quic_upgrade_via_dns_rr* is true and a SOCK_STREAM resolution
        is requested, an extra HTTPS-type query is issued; an ``h3`` ALPN
        hint in the answer yields additional SOCK_DGRAM entries, sorted
        first, to invite an HTTP/3 attempt.

        :raises socket.gaierror: on invalid input, DNS-level errors, or loss
            of the underlying connection.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass the network entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    # NOTE(review): proto is fixed (6/17) here regardless of
                    # *type*, unlike the computed value below — confirm intended.
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        # First caller connects the shared socket; concurrent callers wait.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()

            assert self.server is not None

            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 53),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )

            self._connect_finalized.set()
        else:
            await self._connect_finalized.wait()

        assert self._socket is not None

        await self._socket.wait_for_readiness()

        remote_preemptive_quic_rr = False

        # Already resolving for a datagram socket: no HTTPS-RR upgrade probe.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        for q in queries:
            payload = bytes(q)

            self._pending.append(q)

            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)

            await self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            await self._read_semaphore.acquire()

            #: There we want to verify if another thread got a response that belong to this thread.
            if self._unconsumed:
                dns_resp = None

                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break

                if dns_resp:
                    self._pending.remove(query)
                    self._unconsumed.remove(dns_resp)
                    self._read_semaphore.release()
                    continue

            try:
                data_in_or_segments = await self._socket.recv(1500)

                # recv may hand back a list of GRO segments.
                if isinstance(data_in_or_segments, list):
                    payloads = data_in_or_segments
                elif data_in_or_segments:
                    payloads = [data_in_or_segments]
                else:
                    payloads = []

                # Keep reading until the RFC 1035 length prefix is satisfied.
                if self._rfc1035_prefix_mandated is True and payloads:
                    payload = b"".join(payloads)

                    while rfc1035_should_read(payload):
                        extra = await self._socket.recv(1500)
                        if isinstance(extra, list):
                            payload += b"".join(extra)
                        else:
                            payload += extra

                    payloads = [payload]
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            if not payloads:
                self._terminated = True
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                )

            pending_raw_identifiers = [_.raw_id for _ in self._pending]

            for payload in payloads:
                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                if self._rfc1035_prefix_mandated is True:
                    fragments = rfc1035_unpack(payload)
                else:
                    fragments = packet_fragment(payload, *pending_raw_identifiers)

                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)

                    if any(dns_resp.id == _.id for _ in queries):
                        responses.append(dns_resp)

                        # Locate the matching pending query, if still present.
                        # Bug fix: the previous for/break reused the loop
                        # variable, which stays bound to an unrelated query
                        # when no match is found, and the wrong entry then
                        # got removed.
                        query_tbr: DomainNameServerQuery | None = None

                        for pending_query in self._pending:
                            if pending_query.id == dns_resp.id:
                                query_tbr = pending_query
                                break

                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )

                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        # Promote SOCK_DGRAM variants (HTTP/3 candidates) ahead of the TCP
        # results, but only when every resolved entry was SOCK_STREAM.
        quic_results = []

        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(PlainResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS shortcut pinned to Cloudflare (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(PlainResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS shortcut pinned to Google Public DNS (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(PlainResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS shortcut pinned to Quad9 (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(PlainResolver):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS shortcut pinned to AdGuard's unfiltered endpoint (94.140.14.140)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this vendor; drop any caller override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ....util import parse_url
from ..factories import ResolverDescription
from ..protocols import ProtocolResolver
from .protocols import AsyncBaseResolver
class AsyncResolverFactory(metaclass=ABCMeta):
    """Locate and instantiate :class:`AsyncBaseResolver` implementations.

    Implementations live in sub-modules of ``contrib.resolver._async`` named
    after the protocol (dashes mapped to underscores); an optional
    *implementation* narrows the lookup to a specific backend sub-module
    (e.g. ``_qh3``), and *specifier* selects a vendor shortcut class.
    """

    @staticmethod
    def has(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
    ) -> bool:
        """Return True when at least one matching resolver can be loaded."""
        package_name: str = __name__.split(".")[0]

        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError:
            return False

        # Consistency fix: mirror the exact predicate used by new() —
        # including the protocol match — so has() never reports an
        # implementation that new() would then reject.
        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )

        return bool(implementations)

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> AsyncBaseResolver:
        """Instantiate the first resolver matching the given criteria.

        :raises NotImplementedError: when no module or no compatible class
            can be located for the protocol/specifier/implementation trio.
        """
        package_name: str = __name__.split(".")[0]

        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e

        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )

        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        # Several candidates may match; keep the last one getmembers yields.
        implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]

        return implementation_target(**kwargs)
class AsyncResolverDescription(ResolverDescription):
    """Describe how a BaseResolver must be instantiated."""

    def new(self) -> AsyncBaseResolver:
        """Instantiate the described resolver through AsyncResolverFactory."""
        # start from user kwargs, then layer parsed server/port/patterns on top
        kwargs = {**self.kwargs}
        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns
        return AsyncResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> AsyncResolverDescription:
        """Build a description from a DNS url, e.g. ``doh+google://`` or ``doh://dns.example/``.

        The scheme selects the protocol and optional ``+specifier``; userinfo
        becomes an ``Authorization`` header (Basic for ``user:pass``, Bearer
        otherwise); query parameters become constructor kwargs with
        best-effort bool/int/float conversion; ``hosts`` becomes the
        host-pattern list.

        :raises ValueError: on a missing scheme or a repeated ``implementation``.
        """
        parsed_url = parse_url(url)
        schema = parsed_url.scheme
        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")
        specifier = None
        implementation = None
        if "+" in schema:
            # e.g. "doh+cloudflare" -> protocol "doh", specifier "cloudflare"
            schema, specifier = tuple(schema.lower().split("+", 1))
        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}
        if parsed_url.path:
            kwargs["path"] = parsed_url.path
        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                # user:password -> HTTP Basic authorization
                # NOTE(review): split(":") without maxsplit raises on
                # passwords containing ':' — confirm this is acceptable.
                username, password = parsed_url.auth.split(":")
                username = username.strip("'\"")
                password = password.strip("'\"")
                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                # a bare token -> Bearer authorization
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)
            for parameter in parameters:
                if not parameters[parameter]:
                    continue  # parameter present but empty: ignore
                parameter_insensible = parameter.lower()
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    # parameter repeated in the query string
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")
                    values = []
                    for e in parameters[parameter]:
                        # each occurrence may itself be comma separated
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)
                    if parameter_insensible in kwargs:
                        # merge with a previously collected value
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue
                    kwargs[parameter_insensible] = values
                    continue
                value: str = parameters[parameter][0].lower().strip(" ")
                if parameter == "implementation":
                    implementation = value
                    continue
                if "," in value:
                    # single occurrence carrying a comma separated list
                    list_of_values = value.split(",")
                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            # NOTE(review): unlike the repeated-parameter branch
                            # above, the merged list is not written back into
                            # kwargs here — confirm whether that is intended.
                            list_of_values.append(kwargs[parameter_insensible])
                        continue
                    kwargs[parameter_insensible] = list_of_values
                    continue
                # best-effort scalar conversion: bool, then int, then float
                value_converted: bool | int | float | None = None
                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)
                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )
        host_patterns: list[str] = []
        if "hosts" in kwargs:
            # "hosts" is not a constructor kwarg: it carries the host patterns
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]
        return AsyncResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
import socket
import typing
from .....util.url import _IPV6_ADDRZ_RE
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class InMemoryResolver(AsyncBaseResolver):
    """Static, in-memory hostname -> address table acting as an async resolver.

    Entries come either from :meth:`register` or from constructor patterns of
    the form ``"hostname:address"``. Anything not registered fails with
    ``socket.gaierror``.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # a manual resolver contacts no upstream server; drop those hints
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)
        # maximum number of distinct hostnames kept in the table
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
        if self._host_patterns:
            # patterns are reused here as "host:addr" seed records
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            # seeds consumed; this resolver keeps no pattern constraints
            self._host_patterns = tuple([])

    def recycle(self) -> AsyncBaseResolver:
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # only explicitly registered hostnames are answerable
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True when *hostname* has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Associate *ipaddr* (an IPv4/IPv6 literal) with *hostname*."""
        if hostname not in self._hosts:
            self._hosts[hostname] = []
        else:
            for e in self._hosts[hostname]:
                t, addr = e
                # substring test also matches a bracketed form of the same
                # address (e.g. "::1" in "[::1]") — skips duplicates
                if addr in ipaddr:
                    return
        if _IPV6_ADDRZ_RE.match(ipaddr):
            # bracketed IPv6 literal: store without the brackets
            self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
        elif is_ipv6(ipaddr):
            self._hosts[hostname].append((socket.AF_INET6, ipaddr))
        else:
            self._hosts[hostname].append((socket.AF_INET, ipaddr))
        if len(self._hosts) > self._maxsize:
            # evict the first key in insertion order (oldest hostname)
            k = None
            for k in self._hosts.keys():
                break
            if k:
                self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every record attached to *hostname*."""
        if hostname in self._hosts:
            del self._hosts[hostname]

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* from the in-memory table, mimicking socket.getaddrinfo()."""
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # IP literals bypass the table entirely
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        if host not in self._hosts:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        for entry in self._hosts[host]:
            addr_type, addr_target = entry
            if family != socket.AF_UNSPEC:
                # honor an explicit address-family restriction
                if family != addr_type:
                    continue
            results.append(
                (
                    addr_type,
                    type,
                    # 6 = IPPROTO_TCP for stream sockets, 17 = IPPROTO_UDP otherwise
                    6 if type == socket.SOCK_STREAM else 17,
                    "",
                    (addr_target, port)
                    if addr_type == socket.AF_INET
                    else (addr_target, port, 0, 0),
                )
            )
        if not results:
            # records existed but none matched the requested family
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        return results

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import socket
import typing
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class NullResolver(AsyncBaseResolver):
    """Resolver that refuses to resolve anything but literal IP addresses.

    Guarantees that no DNS traffic is ever emitted: IPv4/IPv6 literals pass
    straight through, every actual hostname raises ``socket.gaierror``.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # no upstream server is ever contacted; discard those hints silently
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> AsyncBaseResolver:
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # nothing to release

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Pass IP literals through unchanged; fail for any real hostname."""
        # normalize inputs the same way CPython's socket module does
        if host is None:
            host = "localhost"
        port = 0 if port is None else (int(port) if isinstance(port, str) else port)
        if port < 0:
            raise socket.gaierror("Servname not supported for ai_socktype")
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            # 6 = IPPROTO_TCP; 2-tuple sockaddr for AF_INET
            return [(socket.AF_INET, type, 6, "", (host, port))]
        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            # 17 = IPPROTO_UDP; 4-tuple sockaddr (flowinfo, scope_id) for AF_INET6
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]
        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,375 @@
from __future__ import annotations
import asyncio
import ipaddress
import socket
import struct
import sys
import typing
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta, timezone
from ...._constant import UDP_LINUX_GRO
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ....exceptions import LocationParseError
from ....util.connection import _set_socket_options, allowed_gai_family
from ....util.timeout import _DEFAULT_TIMEOUT
from ...ssa import AsyncSocket
from ...ssa._timeout import timeout as timeout_
from ..protocols import BaseResolver
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
    """Async counterpart of BaseResolver: resolves names and opens AsyncSockets."""

    def recycle(self) -> AsyncBaseResolver:
        return super().recycle()  # type: ignore[return-value]

    @abstractmethod
    async def close(self) -> None:  # type: ignore[override]
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    async def create_connection(  # type: ignore[override]
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> AsyncSocket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.
        """
        host, port = address
        if host.startswith("["):
            # strip the brackets off an IPv6 literal such as "[::1]"
            host = host.strip("[]")
        err = None
        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()
        if family != socket.AF_UNSPEC:
            default_socket_family = family
        if source_address is not None:
            # match the resolution family to the bind address family
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6
        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None
        dt_pre_resolve = datetime.now(tz=timezone.utc)
        if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
            # we can hang here in case of bad networking conditions
            # the DNS may never answer or the packets can be lost.
            # this isn't possible in sync mode. unfortunately.
            # found by user at https://github.com/jawah/niquests/issues/183
            # todo: find a way to limit getaddrinfo delays in sync mode.
            try:
                async with timeout_(timeout):
                    records = await self.getaddrinfo(
                        host,
                        port,
                        default_socket_family,
                        socket_kind,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
            except TimeoutError:
                raise socket.gaierror(
                    f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
                )
        else:
            records = await self.getaddrinfo(
                host,
                port,
                default_socket_family,
                socket_kind,
                quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
            )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
        dt_pre_established = datetime.now(tz=timezone.utc)
        # try each resolved address in order; keep the last error for re-raise
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = AsyncSocket(af, socktype, proto)
                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (OSError, AttributeError):  # Defensive: very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass
                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass
                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass
                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)
                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)
                try:
                    await sock.connect(sa)
                except asyncio.CancelledError:
                    # don't leak the socket when the task is cancelled mid-connect
                    sock.close()
                    raise
                # Break explicitly a reference cycle
                err = None
                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )
                if timing_hook is not None:
                    # (resolution time, establishment time, completion timestamp)
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )
                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                if isinstance(_, OverflowError):
                    # e.g. port out of range — retrying other records won't help
                    break
        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
class AsyncManyResolver(AsyncBaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: AsyncBaseResolver) -> None:
        super().__init__(None, None)
        self._size = len(resolvers)
        # resolvers that answer any hostname
        self._unconstrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        # resolvers restricted to specific host patterns
        self._constrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]
        # count of in-flight lookups; also rotates the round-robin start index
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> AsyncBaseResolver:
        resolvers = []
        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())
        return AsyncManyResolver(*resolvers)

    async def close(self) -> None:  # type: ignore[override]
        for resolver in self._unconstrained + self._constrained:
            await resolver.close()
        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[AsyncBaseResolver, None, None]:
        """Yield child resolvers round-robin, recycling unavailable ones in place."""
        resolvers = self._unconstrained if not constrained else self._constrained
        if not resolvers:
            return
        with self._lock:
            self._concurrent += 1
        try:
            resolver_count = len(resolvers)
            # spread load across callers: start from a rotating index
            start_idx = (self._concurrent - 1) % resolver_count
            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]
            if start_idx > 0:
                # wrap around to cover resolvers before the start index
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Try constrained resolvers that support *host* first, then unconstrained ones."""
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"
        # NOTE(review): collected but never read afterwards — dead bookkeeping?
        tested_resolvers = []
        any_constrained_tried: bool = False
        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)
            if can_resolve is True:
                any_constrained_tried = True
                try:
                    results = await resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
                    if results:
                        return results
                except socket.gaierror as exc:
                    # DNSSEC/DNSKEY failures are authoritative: never fall through
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)
        if any_constrained_tried:
            # a constrained resolver claimed the host but failed: do not leak
            # the query to unconstrained resolvers
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )
        for resolver in self.__resolvers():
            try:
                results = await resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )
                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue
        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
import socket
import typing
from ...protocols import ProtocolResolver
from ..protocols import AsyncBaseResolver
class SystemResolver(AsyncBaseResolver):
    """Async resolver that defers to the operating system resolver.

    Resolution is delegated to the running event loop's ``getaddrinfo``,
    which ultimately calls into the platform's own resolution stack.
    """

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # the OS picks the nameservers; server/port hints are meaningless here
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Always volunteer for an unspecified host or "localhost"."""
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> AsyncBaseResolver:
        # stateless: reusing the very same instance is safe
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # nothing to tear down

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Delegate straight to the event loop (thus the OS) resolver."""
        loop = asyncio.get_running_loop()
        return await loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,641 @@
from __future__ import annotations
import socket
import typing
from base64 import b64encode
from ...._collections import HTTPHeaderDict
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
from ....connectionpool import HTTPSConnectionPool
from ....response import HTTPResponse
from ....util.url import parse_url
from ..protocols import (
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
class HTTPSResolver(BaseResolver):
    """
    Advanced DNS over HTTPS resolver.
    No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
    also implemented at Cloudflare.
    Support RFC 8484 without JSON. Disabled by default.
    """

    implementation = "urllib3"
    protocol = ProtocolResolver.DOH

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Build the resolver and its underlying HTTPS connection pool.

        Keyword arguments are normalized (headers, proxy, source_address,
        disabled_svn, on_post_connection, ...) then forwarded to
        ``HTTPSConnectionPool``.
        """
        super().__init__(server, port or 443, *patterns, **kwargs)
        # Google/Cloudflare JSON endpoint; overridable via the "path" kwarg
        self._path: str = "/resolve"
        if "path" in kwargs:
            if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
                self._path = kwargs["path"]
            kwargs.pop("path")
        # RFC 8484 binary (application/dns-message) mode, off by default
        self._rfc8484: bool = False
        if "rfc8484" in kwargs:
            if kwargs["rfc8484"]:
                self._rfc8484 = True
            kwargs.pop("rfc8484")
        assert self._server is not None
        if "source_address" in kwargs:
            # accept an "ip:port" string and convert it to a 2-tuple
            if isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
                if bind_ip and bind_port.isdigit():
                    kwargs["source_address"] = (
                        bind_ip,
                        int(bind_port),
                    )
                else:
                    raise ValueError("invalid source_address given in parameters")
            else:
                raise ValueError("invalid source_address given in parameters")
        if "proxy" in kwargs:
            kwargs["_proxy"] = parse_url(kwargs["proxy"])
            kwargs.pop("proxy")
        if "maxsize" not in kwargs:
            kwargs["maxsize"] = 10
        if "proxy_headers" in kwargs and "_proxy" in kwargs:
            # "Name: value" strings -> HTTPHeaderDict
            proxy_headers = HTTPHeaderDict()
            if not isinstance(kwargs["proxy_headers"], list):
                kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
            for item in kwargs["proxy_headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")
                k, v = item.split(":", 1)
                proxy_headers.add(k, v)
            kwargs["_proxy_headers"] = proxy_headers
        if "headers" in kwargs:
            headers = HTTPHeaderDict()
            if not isinstance(kwargs["headers"], list):
                kwargs["headers"] = [kwargs["headers"]]
            for item in kwargs["headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")
                k, v = item.split(":", 1)
                headers.add(k, v)
            kwargs["headers"] = headers
        if "disabled_svn" in kwargs:
            # map "h11"/"h2"/"h3" strings to HttpVersion members
            if not isinstance(kwargs["disabled_svn"], list):
                kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
            disabled_svn = set()
            for svn in kwargs["disabled_svn"]:
                svn = svn.lower()
                if svn == "h11":
                    disabled_svn.add(HttpVersion.h11)
                elif svn == "h2":
                    disabled_svn.add(HttpVersion.h2)
                elif svn == "h3":
                    disabled_svn.add(HttpVersion.h3)
            kwargs["disabled_svn"] = disabled_svn
        if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
            # invoked with ConnectionInfo after the pool establishes a connection
            self._connection_callback: (
                typing.Callable[[ConnectionInfo], None] | None
            ) = kwargs["on_post_connection"]
            kwargs.pop("on_post_connection")
        else:
            self._connection_callback = None
        self._pool = HTTPSConnectionPool(self._server, self._port, **kwargs)

    def close(self) -> None:
        self._pool.close()

    def is_available(self) -> bool:
        return self._pool.pool is not None

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over HTTPS, mimicking socket.getaddrinfo().

        Issues A/AAAA (and optionally HTTPS/type-65) queries, multiplexed on
        the pool, in either JSON or RFC 8484 binary form. A type-65 record
        advertising h3 may yield extra SOCK_DGRAM results for QUIC upgrade.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # IP literals short-circuit: no remote query needed
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        validate_length_of(host)
        promises: list[HTTPResponse | ResponsePromise] = []
        remote_preemptive_quic_rr = False
        # already asking for a datagram socket: nothing to "upgrade" to
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                # JSON form; DNS type 1 = A record
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            # RFC 8484 requires unpadded base64url
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                # DNS type 28 = AAAA record
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if quic_upgrade_via_dns_rr:
            if not self._rfc8484:
                # DNS type 65 = HTTPS RR (carries alpn / QUIC hints)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        # gather the multiplexed responses (some pools may answer eagerly)
        responses: list[HTTPResponse] = []
        for promise in promises:
            if isinstance(promise, HTTPResponse):
                responses.append(promise)
                continue
            responses.append(self._pool.get_response(promise=promise))  # type: ignore[arg-type]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )
            if not self._rfc8484:
                # JSON (Google/Cloudflare) response form
                payload = response.json()
                assert "Status" in payload and isinstance(payload["Status"], int)
                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )
                    if isinstance(msg, list):
                        msg = ", ".join(msg)
                    raise socket.gaierror(msg)
                assert "Question" in payload and isinstance(payload["Question"], list)
                if "Answer" not in payload:
                    # NOERROR but no records for this query type
                    continue
                assert isinstance(payload["Answer"], list)
                for answer in payload["Answer"]:
                    if answer["type"] not in [1, 28, 65]:
                        continue
                    assert "data" in answer
                    assert isinstance(answer["data"], str)
                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]
                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            rr = "".join(rr[2:].split(" ")[2:])
                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""
                            https_record = parse_https_rdata(raw_record)
                            if "h3" not in https_record["alpn"]:
                                continue
                            remote_preemptive_quic_rr = True
                        else:
                            # textual presentation form: "key=value" pairs
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )
                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue
                            remote_preemptive_quic_rr = True
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )
                        continue
                    # plain A (type 1) or AAAA (type 28) answer
                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )
                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )
                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # RFC 8484 binary response form
                dns_resp = DomainNameServerReturn(response.data)
                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue
                    assert not isinstance(record[-1], dict)
                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )
                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )
                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        if remote_preemptive_quic_rr:
            # mirror each TCP result as a UDP/QUIC candidate — unless the
            # caller explicitly asked for datagram sockets already
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []
        # sort so that higher family/socktype values come first (QUIC preferred)
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over HTTPS preset for Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        kwargs.pop("server", None)  # the hostname is fixed by this preset
        port = kwargs.pop("port", None)
        if "rfc8484" in kwargs and kwargs["rfc8484"]:
            # Google serves RFC 8484 (binary) queries under /dns-query
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over HTTPS preset for Cloudflare (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        kwargs.pop("server", None)  # the hostname is fixed by this preset
        port = kwargs.pop("port", None)
        # Cloudflare serves both JSON and RFC 8484 queries under /dns-query
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over HTTPS preset for AdGuard unfiltered (unfiltered.adguard-dns.com)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        kwargs.pop("server", None)  # the hostname is fixed by this preset
        port = kwargs.pop("port", None)
        # AdGuard only speaks RFC 8484 (binary) on /dns-query
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over HTTPS preset for Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        kwargs.pop("server", None)  # the hostname is fixed by this preset
        port = kwargs.pop("port", None)
        # OpenDNS only speaks RFC 8484 (binary) on /dns-query
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Quad9 DNS-over-HTTPS (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Force the well-known endpoint and the RFC 8484 binary wire format.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to NextDNS DNS-over-HTTPS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations

# DNS-over-QUIC support depends on the optional ``qh3`` package (and on a
# non-FIPS ssl build — see the guard in ._qh3). When the import fails we
# publish ``None`` placeholders instead, so callers can feature-detect DoQ
# availability rather than crash at import time.
try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    QUICResolver = None  # type: ignore
    AdGuardResolver = None  # type: ignore
    NextDNSResolver = None  # type: ignore

__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)

View File

@@ -0,0 +1,541 @@
from __future__ import annotations
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from ....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...ssa._gro import _sock_has_gro, _sock_has_gso, sync_recv_gro, sync_sendmsg_gso
from ..dou import PlainResolver
from ..protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
# Refuse to load the DoQ backend under a FIPS-compliant ssl build; the
# ImportError is caught by the package __init__, which then exposes None
# placeholders so the rest of the library degrades gracefully.
if IS_FIPS:
    raise ImportError(
        "DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
    )
class QUICResolver(PlainResolver):
    """
    DNS resolver speaking DNS-over-QUIC (DoQ) through the ``qh3`` package.

    Each query is sent on its own bidirectional QUIC stream with the two-byte
    RFC 1035 length prefix, and the connection negotiates the ``doq`` ALPN.
    Inherits the UDP socket plumbing from :class:`PlainResolver`.
    """

    protocol = ProtocolResolver.DOQ
    implementation = "qh3"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        # PlainResolver opens the UDP socket toward the server (default 853).
        super().__init__(server, port or 853, *patterns, **kwargs)

        # qh3 load_default_certs seems off. need to investigate.
        # Workaround: mirror the system trust store into ca_cert_data so qh3
        # can verify the remote peer certificate.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

            try:
                ctx.load_default_certs()

                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))

                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    # No CA extracted; drop the key so the check below applies.
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                # System store unusable; drop the key so the check below applies.
                del kwargs["ca_cert_data"]

        # Without any CA and with verification required, fail loudly up front.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )

        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )

        # Optional mutual-TLS client certificate, from file or in-memory data.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )

        self._quic = QuicConnection(configuration=configuration)

        # UDP generic receive/segmentation offload support, probed per-socket.
        self._dgram_gro_enabled: bool = _sock_has_gro(self._socket)
        self._dgram_gso_enabled: bool = _sock_has_gso(self._socket)

        # Perform the QUIC handshake eagerly, before any query is issued.
        self._quic.connect((self._server, self._port), monotonic())
        self.__exchange_until(HandshakeCompleted, receive_first=False)

        self._terminated: bool = False
        self._should_disconnect: bool = False

        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True

        # Responses read by one thread on behalf of another, and queries
        # awaiting their answer. Access is guarded by self._lock.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

    def close(self) -> None:
        """Close the QUIC connection gracefully, flushing its final frames."""
        if not self._terminated:
            with self._lock:
                self._quic.close()

                # Drain the CONNECTION_CLOSE (and any remaining) datagrams.
                while True:
                    datagrams = self._quic.datagrams_to_send(monotonic())

                    if not datagrams:
                        break

                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])

                self._socket.close()
                self._terminated = True

    def is_available(self) -> bool:
        """Whether the QUIC connection can still serve queries."""
        # Advance qh3 internal timers (idle timeout, loss detection).
        self._quic.handle_timer(monotonic())

        # NOTE(review): peeks at qh3's private _close_event to detect an
        # already-terminated connection — confirm against the qh3 version in use.
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True

        return not self._terminated

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """
        Resolve ``host`` over DoQ, mimicking :func:`socket.getaddrinfo` output.

        Literal IPv4/IPv6 hosts short-circuit without any network exchange.
        When ``quic_upgrade_via_dns_rr`` is set, an HTTPS resource record
        advertising ``h3`` prepends SOCK_DGRAM entries to the results.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass the resolver entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        remote_preemptive_quic_rr = False

        # An HTTPS-RR upgrade only makes sense for stream sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Build the record-type list the address family asks for.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        open_streams = []

        with self._lock:
            # One dedicated, fin-terminated stream per query.
            for q in queries:
                payload = bytes(q)

                self._pending.append(q)

                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)

                stream_id = self._quic.get_next_available_stream_id()
                self._quic.send_stream_data(stream_id, payload, True)
                open_streams.append(stream_id)

            datagrams = self._quic.datagrams_to_send(monotonic())

            if self._dgram_gso_enabled and len(datagrams) > 1:
                sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
            else:
                for dg in datagrams:
                    self._socket.sendall(dg[0])

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            with self._lock:
                # Another thread may already have read one of our answers.
                if self._unconsumed:
                    dns_resp = None

                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break

                    if dns_resp:
                        self._unconsumed.remove(dns_resp)
                        self._pending.remove(query)
                        continue

                try:
                    events: list[StreamDataReceived] = self.__exchange_until(  # type: ignore[assignment]
                        StreamDataReceived,
                        receive_first=True,
                        event_type_collectable=(StreamDataReceived,),
                        respect_end_stream_signal=False,
                    )

                    payload = b"".join([e.data for e in events])

                    # Keep reading until the 2-byte length prefix is satisfied.
                    while rfc1035_should_read(payload):
                        events.extend(
                            self.__exchange_until(  # type: ignore[arg-type]
                                StreamDataReceived,
                                receive_first=True,
                                event_type_collectable=(StreamDataReceived,),
                                respect_end_stream_signal=False,
                            )
                        )
                        payload = b"".join([e.data for e in events])
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e

                if not payload:
                    continue

                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                fragments = rfc1035_unpack(payload)

                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)

                    if any(dns_resp.id == _.id for _ in queries):
                        responses.append(dns_resp)

                        query_tbr: DomainNameServerQuery | None = None

                        for query_tbr in self._pending:
                            if query_tbr.id == dns_resp.id:
                                break

                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        # Not ours: park it for the thread that asked for it.
                        self._unconsumed.append(dns_resp)

        # Honor a server-initiated STOP_SENDING observed during the exchange.
        if self._should_disconnect:
            with self._lock:
                self.close()
                self._should_disconnect = False
                self._terminated = True

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs only inform the h3 upgrade decision.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )

                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        # h3 advertised: duplicate TCP results as datagram entries, unless the
        # caller already asked for a specific non-stream socket kind.
        quic_results = []

        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)

    def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """
        Pump the QUIC connection until an event of ``event_type`` occurs.

        Collected events (filtered by ``event_type_collectable`` when given)
        are returned. With ``respect_end_stream_signal`` the loop only stops
        on a stream-ending occurrence of ``event_type``.
        """
        while True:
            # Flush anything qh3 queued before we start reading.
            if receive_first is False:
                now = monotonic()

                while True:
                    datagrams = self._quic.datagrams_to_send(now)

                    if not datagrams:
                        break

                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])

            events = []

            while True:
                # Only hit the socket when qh3 has no event queued already.
                if not self._quic._events:
                    if self._dgram_gro_enabled:
                        data_in = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in = self._socket.recv(1500)

                    if not data_in:
                        break

                    now = monotonic()

                    if isinstance(data_in, list):
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )
                    else:
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )

                    # Immediately flush ACKs / control frames produced above.
                    while True:
                        now = monotonic()
                        datagrams = self._quic.datagrams_to_send(now)

                        if not datagrams:
                            break

                        if self._dgram_gso_enabled and len(datagrams) > 1:
                            sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                        else:
                            for datagram in datagrams:
                                self._socket.sendall(datagram[0])

                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        # Remember to tear the connection down once the
                        # in-flight queries have been answered.
                        self._should_disconnect = True
                        continue

                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)

                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events

            return events
class AdGuardResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to AdGuard unfiltered DNS-over-QUIC."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to NextDNS DNS-over-QUIC (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import socket
import typing
from ....util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
from ..dou import PlainResolver
from ..protocols import ProtocolResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        if "timeout" in kwargs and isinstance(kwargs["timeout"], (int, float)):
            timeout = kwargs["timeout"]
        else:
            timeout = None

        # Optional "ip:port" local bind address for the outgoing connection.
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            bind_ip, bind_port = kwargs["source_address"].split(":", 1)
        else:
            bind_ip, bind_port = "0.0.0.0", "0"

        # Open the TCP socket first: PlainResolver.__init__ only creates a
        # socket when self._socket is absent, so setting it here makes the
        # superclass reuse this stream socket (DoT default port 853).
        self._socket = SystemResolver().create_connection(
            (server, port or 853),
            timeout=timeout,
            source_address=(bind_ip, int(bind_port))
            if bind_ip != "0.0.0.0" or bind_port != "0"
            else None,
            socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
            socket_kind=socket.SOCK_STREAM,
        )

        # NOTE(review): the raw ``port`` (possibly None) is forwarded to the
        # superclass while the connection above used ``port or 853`` — the
        # QUIC sibling passes ``port or 853`` instead; confirm intentional.
        super().__init__(server, port, *patterns, **kwargs)

        # Upgrade the established TCP socket to TLS with caller-provided
        # certificate material (files or in-memory data).
        self._socket = ssl_wrap_socket(
            self._socket,
            server_hostname=server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            keyfile=kwargs["key_file"] if "key_file" in kwargs else None,
            certfile=kwargs["cert_file"] if "cert_file" in kwargs else None,
            cert_reqs=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else None,
            ca_certs=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            ssl_version=kwargs["ssl_version"] if "ssl_version" in kwargs else None,
            ciphers=kwargs["ciphers"] if "ciphers" in kwargs else None,
            ca_cert_dir=kwargs["ca_cert_dir"] if "ca_cert_dir" in kwargs else None,
            key_password=kwargs["key_password"] if "key_password" in kwargs else None,
            ca_cert_data=kwargs["ca_cert_data"] if "ca_cert_data" in kwargs else None,
            certdata=kwargs["cert_data"] if "cert_data" in kwargs else None,
            keydata=kwargs["key_data"] if "key_data" in kwargs else None,
        )

        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True
class GoogleResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Google public DNS-over-TLS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Cloudflare DNS-over-TLS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to AdGuard unfiltered DNS-over-TLS."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Cisco OpenDNS DNS-over-TLS."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Quad9 DNS-over-TLS (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,415 @@
from __future__ import annotations
import socket
import typing
from collections import deque
from ...ssa._gro import _sock_has_gro, sync_recv_gro
from ..protocols import (
COMMON_RCODE_LABEL,
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..system import SystemResolver
from ..utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
class PlainResolver(BaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        super().__init__(server, port, *patterns, **kwargs)

        # Subclasses (e.g. DoT) may have created their own socket already;
        # only open a UDP socket (default port 53) when none exists.
        if not hasattr(self, "_socket"):
            if "timeout" in kwargs and isinstance(
                kwargs["timeout"],
                (
                    float,
                    int,
                ),
            ):
                timeout = kwargs["timeout"]
            else:
                timeout = None

            # Optional "ip:port" local bind address.
            if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
            else:
                bind_ip, bind_port = "0.0.0.0", "0"

            self._socket = SystemResolver().create_connection(
                (server, port or 53),
                timeout=timeout,
                source_address=(bind_ip, int(bind_port))
                if bind_ip != "0.0.0.0" or bind_port != "0"
                else None,
                socket_options=None,
                socket_kind=socket.SOCK_DGRAM,
            )

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # UDP generic receive offload support, probed per-socket.
        self._gro_enabled: bool = _sock_has_gro(self._socket)

        # Responses read by one thread on behalf of another, and queries
        # awaiting their answer. Access is guarded by self._lock.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

        self._terminated: bool = False

    def close(self) -> None:
        """Shut down and close the underlying socket exactly once."""
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    self._socket.shutdown(0)
                    self._socket.close()
                self._terminated = True

    def is_available(self) -> bool:
        """Whether this resolver can still serve queries."""
        return not self._terminated

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """
        Resolve ``host``, mimicking :func:`socket.getaddrinfo` output.

        Literal IPv4/IPv6 hosts short-circuit without any network exchange.
        When ``quic_upgrade_via_dns_rr`` is set, an HTTPS resource record
        advertising ``h3`` prepends SOCK_DGRAM entries to the results.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass the resolver entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        remote_preemptive_quic_rr = False

        # An HTTPS-RR upgrade only makes sense for stream sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Build the record-type list the address family asks for.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        with self._lock:
            for q in queries:
                payload = bytes(q)

                self._pending.append(q)

                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)

                self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            with self._lock:
                #: There we want to verify if another thread got a response that belong to this thread.
                if self._unconsumed:
                    dns_resp = None

                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break

                    if dns_resp:
                        self._pending.remove(query)
                        self._unconsumed.remove(dns_resp)
                        continue

                try:
                    if self._gro_enabled:
                        data_in_or_segments = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in_or_segments = self._socket.recv(1500)

                    # GRO may hand back several coalesced segments at once.
                    if isinstance(data_in_or_segments, list):
                        payloads = data_in_or_segments
                    elif data_in_or_segments:
                        payloads = [data_in_or_segments]
                    else:
                        payloads = []

                    # Length-prefixed transports (e.g. DoT): keep reading
                    # until the announced message size is satisfied.
                    if self._rfc1035_prefix_mandated is True and payloads:
                        payload = b"".join(payloads)

                        while rfc1035_should_read(payload):
                            extra = self._socket.recv(1500)
                            if isinstance(extra, list):
                                payload += b"".join(extra)
                            else:
                                payload += extra

                        payloads = [payload]
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e

                if not payloads:
                    self._terminated = True
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    )

                pending_raw_identifiers = [_.raw_id for _ in self._pending]

                for payload in payloads:
                    #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                    if self._rfc1035_prefix_mandated is True:
                        fragments = rfc1035_unpack(payload)
                    else:
                        fragments = packet_fragment(payload, *pending_raw_identifiers)

                    for fragment in fragments:
                        dns_resp = DomainNameServerReturn(fragment)

                        if any(dns_resp.id == _.id for _ in queries):
                            responses.append(dns_resp)

                            query_tbr: DomainNameServerQuery | None = None

                            for query_tbr in self._pending:
                                if query_tbr.id == dns_resp.id:
                                    break

                            if query_tbr:
                                self._pending.remove(query_tbr)
                        else:
                            # Not ours: park it for the thread that asked for it.
                            self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs only inform the h3 upgrade decision.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )

                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        # h3 advertised: duplicate TCP results as datagram entries, unless the
        # caller already asked for a specific non-stream socket kind.
        quic_results = []

        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Cloudflare plain DNS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Google plain DNS (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to Quad9 plain DNS (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Shortcut resolver pinned to AdGuard unfiltered plain DNS (94.140.14.140)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; a caller-provided server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,230 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ...util import parse_url
from .protocols import BaseResolver, ProtocolResolver
class ResolverFactory(metaclass=ABCMeta):
    """Locates and instantiates a :class:`BaseResolver` implementation by name."""

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> BaseResolver:
        """
        Import the sub-package matching ``protocol`` (optionally narrowed to a
        private ``_implementation`` module), pick the resolver class whose
        ``specifier`` matches, and construct it with ``kwargs``.

        Raises NotImplementedError when the module cannot be imported or no
        matching class is found.
        """
        # The top-level package name hosting contrib.resolver (relative import anchor).
        package_name: str = __name__.split(".")[0]

        # e.g. protocol "dns-over-tls" -> relative module ".dns_over_tls".
        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e

        # All BaseResolver subclasses in the module whose specifier matches
        # (getmembers returns them sorted by name).
        implementations: list[tuple[str, type[BaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, BaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            ),
        )

        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        # NOTE(review): when several classes match, the alphabetically last
        # one wins — confirm this tie-break is intentional.
        implementation_target: type[BaseResolver] = implementations.pop()[1]

        return implementation_target(**kwargs)
class ResolverDescription:
    """Describe how a BaseResolver must be instantiated."""

    def __init__(
        self,
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        server: str | None = None,
        port: int | None = None,
        *host_patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        self.protocol = protocol
        self.specifier = specifier
        self.implementation = implementation
        self.server = server
        self.port = port
        self.host_patterns = host_patterns
        self.kwargs = kwargs

    def __setitem__(self, key: str, value: typing.Any) -> None:
        """Set an extra constructor keyword argument."""
        self.kwargs[key] = value

    def __contains__(self, item: str) -> bool:
        """Whether an extra constructor keyword argument is already set."""
        return item in self.kwargs

    def new(self) -> BaseResolver:
        """Instantiate the described resolver through :class:`ResolverFactory`."""
        kwargs = {**self.kwargs}

        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns

        return ResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> ResolverDescription:
        """
        Build a description from a DNS url, e.g. ``dou://1.1.1.1`` or
        ``doh+google://default/?implementation=urllib3``.

        Query parameters become constructor kwargs (booleans, ints and floats
        are auto-converted; comma-separated or repeated parameters become
        lists); ``implementation`` and ``hosts`` are extracted specially, and
        url credentials turn into an Authorization header.
        """
        parsed_url = parse_url(url)

        schema = parsed_url.scheme

        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")

        specifier = None
        implementation = None

        # "doh+google" -> protocol "doh", specifier "google".
        if "+" in schema:
            schema, specifier = tuple(schema.lower().split("+", 1))

        protocol = ProtocolResolver(schema)

        kwargs: dict[str, typing.Any] = {}

        if parsed_url.path:
            kwargs["path"] = parsed_url.path

        if parsed_url.auth:
            kwargs["headers"] = dict()

            # "user:pass" -> Basic auth; a bare token -> Bearer auth.
            if ":" in parsed_url.auth:
                username, password = parsed_url.auth.split(":")

                username = username.strip("'\"")
                password = password.strip("'\"")

                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"

        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)

            for parameter in parameters:
                if not parameters[parameter]:
                    continue

                parameter_insensible = parameter.lower()

                # Repeated parameter (?a=1&a=2): always yields a list.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")

                    values = []

                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue

                    kwargs[parameter_insensible] = values
                    continue

                value: str = parameters[parameter][0].lower().strip(" ")

                if parameter == "implementation":
                    implementation = value
                    continue

                # Single parameter holding a comma-separated list.
                if "," in value:
                    list_of_values = value.split(",")

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            list_of_values.append(kwargs[parameter_insensible])
                            # Bug fix: the merged list used to be discarded
                            # here, silently losing the previously stored
                            # scalar value (the repeated-parameter branch
                            # above already stores it back).
                            kwargs[parameter_insensible] = list_of_values
                        continue

                    kwargs[parameter_insensible] = list_of_values
                    continue

                # Scalar: auto-convert booleans, ints and simple floats.
                value_converted: bool | int | float | None = None

                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)

                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )

        host_patterns: list[str] = []

        # "hosts" restricts which hostnames this resolver may serve.
        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]

        return ResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
import socket
import typing
from ....util.url import _IPV6_ADDRZ_RE
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class InMemoryResolver(BaseResolver):
    """A purely local resolver backed by an in-memory hostname table.

    Records can be provided up-front through ``patterns`` entries of the
    form ``"hostname:ipaddr"``, or added later via :meth:`register`.
    No network traffic is ever emitted by this resolver.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no upstream server; drop any server/port that the
        # generic resolver construction path may have forwarded.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)
        # Maximum number of distinct hostnames kept before eviction kicks in.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        # hostname -> list of (address family, address string) records.
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
        if self._host_patterns:
            # Seed the table from "hostname:ipaddr" entries; anything without
            # a colon is silently ignored.
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            # Patterns were consumed as records; constraints are expressed by
            # table membership from here on (see support()).
            self._host_patterns = tuple([])
        # NOTE(review): with exactly one hostname mapping to one address there
        # is no address race, so the in-progress socket may be exposed for the
        # synchronous Happy Eyeballs implementation — confirm against callers.
        if len(self._hosts) == 1 and len(self._hosts[list(self._hosts.keys())[0]]) == 1:
            self._unsafe_expose = True

    def recycle(self) -> BaseResolver:
        # Stateless from the pool's perspective: reuse the same instance.
        return self

    def close(self) -> None:
        pass  # no-op: nothing to release

    def is_available(self) -> bool:
        # Never closed, always usable.
        return True

    def have_constraints(self) -> bool:
        # Only hostnames present in the table are resolvable.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True only when *hostname* has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add a record for *hostname*, classifying *ipaddr* as IPv4 or IPv6."""
        with self._lock:
            if hostname not in self._hosts:
                self._hosts[hostname] = []
            else:
                # De-duplicate against already-registered addresses.
                # NOTE(review): substring containment (`addr in ipaddr`) —
                # tolerates bracketed IPv6 forms, but confirm that exact
                # equality was not intended here.
                for e in self._hosts[hostname]:
                    t, addr = e
                    if addr in ipaddr:
                        return
            if _IPV6_ADDRZ_RE.match(ipaddr):
                # Bracketed IPv6 literal: strip "[" and "]" before storing.
                self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
            elif is_ipv6(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr))
            else:
                self._hosts[hostname].append((socket.AF_INET, ipaddr))
            if len(self._hosts) > self._maxsize:
                # Evict the oldest-inserted hostname (dicts preserve insertion
                # order) to stay within the configured bound.
                k = None
                for k in self._hosts.keys():
                    break
                if k:
                    self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Drop every record registered for *hostname* (no-op if absent)."""
        with self._lock:
            if hostname in self._hosts:
                del self._hosts[hostname]

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* from the in-memory table, mimicking stdlib semantics.

        IP literals are echoed back without a table lookup. Raises
        ``socket.gaierror`` when no matching record exists.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        with self._lock:
            if host not in self._hosts:
                raise socket.gaierror(f"no records found for hostname {host} in-memory")
            for entry in self._hosts[host]:
                addr_type, addr_target = entry
                # Honor an explicit address-family filter.
                if family != socket.AF_UNSPEC:
                    if family != addr_type:
                        continue
                results.append(
                    (
                        addr_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        (addr_target, port)
                        if addr_type == socket.AF_INET
                        else (addr_target, port, 0, 0),
                    )
                )
        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        return results

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class NullResolver(BaseResolver):
    """Resolver that deliberately refuses to resolve hostnames.

    IP literals (v4/v6) are still echoed back, mirroring the standard
    library's fast path, but any request that would require a real DNS
    lookup raises ``socket.gaierror``.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # A null resolver has no upstream server; discard those kwargs if any.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> BaseResolver:
        # Stateless: the same instance can always be reused.
        return self

    def close(self) -> None:
        # Nothing to tear down.
        pass

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Echo back IP literals; raise for anything needing actual resolution."""
        # Input normalization, matching CPython's stdlib behavior.
        if host is None:
            host = "localhost"
        if port is None:
            port = 0
        if isinstance(port, str):
            port = int(port)
        if port < 0:
            raise socket.gaierror("Servname not supported for ai_socktype")
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET, type, 6, "", (host, port))]
        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]
        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,655 @@
from __future__ import annotations
import ipaddress
import socket
import struct
import sys
import threading
import typing
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from enum import Enum
from random import randint
from ..._constant import UDP_LINUX_GRO
from ..._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ...exceptions import LocationParseError
from ...util.connection import _set_socket_options, allowed_gai_family
from ...util.ssl_match_hostname import CertificateError, match_hostname
from ...util.timeout import _DEFAULT_TIMEOUT
from .utils import inet4_ntoa, inet6_ntoa, parse_https_rdata
if typing.TYPE_CHECKING:
from .utils import HttpsRecord
class ProtocolResolver(str, Enum):
    """
    At urllib3.future we aim to propose a wide range of DNS-protocols.
    The most used techniques are available.

    Each member's value is the scheme string used to select the protocol
    in a resolver specifier URL.
    """

    #: Ask the OS native DNS layer
    SYSTEM = "system"
    #: DNS over HTTPS
    DOH = "doh"
    #: DNS over QUIC
    DOQ = "doq"
    #: DNS over TLS
    DOT = "dot"
    #: DNS over UDP (insecure)
    DOU = "dou"
    #: Manual (e.g. hosts)
    MANUAL = "in-memory"
    #: Void (e.g. purposely disable resolution)
    NULL = "null"
    #: Custom (e.g. your own implementation, use this when it does not suit any of the protocols specified)
    CUSTOM = "custom"
class BaseResolver(metaclass=ABCMeta):
    """Abstract contract shared by every DNS resolver implementation.

    Concrete subclasses declare the protocol they speak (see
    :class:`ProtocolResolver`), may restrict themselves to a set of host
    patterns, and must implement :meth:`getaddrinfo`, :meth:`close` and
    :meth:`is_available`.
    """

    #: DNS transport/protocol the subclass speaks.
    protocol: typing.ClassVar[ProtocolResolver]
    #: Optional sub-variant of the protocol (e.g. a provider preset).
    specifier: typing.ClassVar[str | None] = None
    #: Short name of the concrete backend (e.g. "socket", "dict").
    implementation: typing.ClassVar[str]

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """
        :param server: Upstream DNS server host, or None for local resolvers.
        :param port: Upstream DNS server port, or None.
        :param patterns: Hostname patterns this resolver is restricted to.
        :param kwargs: Implementation-specific options, kept for recycle().
        """
        self._server = server
        self._port = port
        self._host_patterns: tuple[str, ...] = patterns
        self._lock = threading.Lock()
        # Kept verbatim so recycle() can rebuild an identical replacement.
        self._kwargs = kwargs
        if not self._host_patterns and "patterns" in kwargs:
            self._host_patterns = kwargs["patterns"]
        # allow to temporarily expose a sock that is "being" created
        # this helps with our Happy Eyeballs implementation in sync.
        self._unsafe_expose: bool = False
        self._sock_cursor: socket.socket | None = None

    def recycle(self) -> BaseResolver:
        """Build a fresh, usable clone of a closed resolver.

        NOTE(review): the subclass constructor is introspected through
        ``__code__.co_varnames`` (which also lists locals) to decide how to
        re-invoke it; this relies on parameters being named ``patterns`` and
        ``kwargs`` — confirm when adding new resolver implementations.
        """
        if self.is_available():
            raise RuntimeError("Attempting to recycle a Resolver that was not closed")
        args = list(self.__class__.__init__.__code__.co_varnames)
        args.remove("self")
        kwargs_cpy = deepcopy(self._kwargs)
        if self._server:
            kwargs_cpy["server"] = self._server
        if self._port:
            kwargs_cpy["port"] = self._port
        if "patterns" in args and "kwargs" in args:
            return self.__class__(*self._host_patterns, **kwargs_cpy)  # type: ignore[arg-type]
        elif "kwargs" in args:
            return self.__class__(**kwargs_cpy)
        return self.__class__()  # type: ignore[call-arg]

    @property
    def server(self) -> str | None:
        # Upstream server as given at construction (None for local resolvers).
        return self._server

    @property
    def port(self) -> int | None:
        # Upstream port as given at construction (None for local resolvers).
        return self._port

    def have_constraints(self) -> bool:
        """Whether this resolver is limited to a set of host patterns."""
        return bool(self._host_patterns)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """
        Determine if given hostname is especially resolvable by given resolver.
        If this resolver does not have any constrained list of host, it returns None. Meaning
        it support any hostname for resolution.
        """
        if not self._host_patterns:
            return None
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        # Reuse the TLS hostname-matching machinery so patterns may contain
        # wildcards, exactly like certificate SAN entries.
        try:
            match_hostname(
                {"subjectAltName": (tuple(("DNS", e) for e in self._host_patterns))},
                hostname,
            )
        except CertificateError:
            return False
        return True

    @abstractmethod
    def close(self) -> None:
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """Determine if Resolver can receive inquiries."""
        raise NotImplementedError

    @abstractmethod
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    def create_connection(
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> socket.socket:
        """Connect to *address* and return the socket object.
        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.

        *timing_hook*, when given, receives (resolution delta, established
        delta, completion timestamp) once the connection succeeds.
        """
        host, port = address
        # Accept bracketed IPv6 literals ("[::1]").
        if host.startswith("["):
            host = host.strip("[]")
        err = None
        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()
        if family != socket.AF_UNSPEC:
            default_socket_family = family
        if source_address is not None:
            # The bind address family dictates which records are usable.
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6
        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None
        dt_pre_resolve = datetime.now(tz=timezone.utc)
        records = self.getaddrinfo(
            host,
            port,
            default_socket_family,
            socket_kind,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
        dt_pre_established = datetime.now(tz=timezone.utc)
        # Try each resolved address in order until one connects.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)
                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address is not None:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (
                        OSError,
                        AttributeError,
                    ):  # Defensive: Windows or very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass
                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass
                    sock.bind(source_address)
                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass
                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)
                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                # Expose the in-flight socket for the sync Happy Eyeballs path.
                if self._unsafe_expose:
                    self._sock_cursor = sock
                sock.connect(sa)
                if self._unsafe_expose:
                    self._sock_cursor = None
                # Break explicitly a reference cycle
                err = None
                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )
                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )
                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. invalid port) will not improve with the
                # next record; stop trying immediately.
                if isinstance(_, OverflowError):
                    break
        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
class ManyResolver(BaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).

    Constrained resolvers (those limited to host patterns) are consulted
    first; unconstrained resolvers act as the general fallback.
    """

    def __init__(self, *resolvers: BaseResolver) -> None:
        super().__init__(None, None)
        self._size = len(resolvers)
        # Resolvers able to answer for any hostname.
        self._unconstrained: list[BaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        # Resolvers restricted to specific host patterns.
        self._constrained: list[BaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]
        # Count of in-flight iterations; also used to rotate the starting
        # resolver for cheap load distribution across children.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> BaseResolver:
        """Recycle every child resolver and wrap them in a new ManyResolver."""
        resolvers = []
        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())
        return ManyResolver(*resolvers)

    def close(self) -> None:
        """Close every child resolver and mark this aggregate as terminated."""
        for resolver in self._unconstrained + self._constrained:
            resolver.close()
        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[BaseResolver, None, None]:
        """Yield child resolvers, rotating the start index per in-flight call.

        Children found unavailable are recycled in place (under lock) before
        being yielded.
        """
        resolvers = self._unconstrained if not constrained else self._constrained
        if not resolvers:
            return
        with self._lock:
            self._concurrent += 1
        try:
            resolver_count = len(resolvers)
            # Rotate the starting point so concurrent callers spread load.
            start_idx = (self._concurrent - 1) % resolver_count
            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]
            # Wrap around to cover the resolvers before the start index.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* by trying constrained children first, then fallbacks.

        DNSSEC-related failures are propagated immediately; other lookup
        errors fall through to the next resolver.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"
        tested_resolvers = []
        any_constrained_tried: bool = False
        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)
            if can_resolve is True:
                any_constrained_tried = True
                try:
                    results = resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
                    if results:
                        return results
                except socket.gaierror as exc:
                    # Security-related failures must not be masked by falling
                    # back to another resolver.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)
        # A constrained resolver claimed this host but failed: do not leak the
        # query to unconstrained resolvers.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )
        for resolver in self.__resolvers():
            try:
                results = resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )
                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue
        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
class SupportedQueryType(int, Enum):
    """
    urllib3.future does not need anything else so far. let's be pragmatic.
    Each type is associated with its hex value as per the RFC.
    """

    #: IPv4 host address record (RFC 1035)
    A = 0x0001
    #: IPv6 host address record (RFC 3596)
    AAAA = 0x001C
    #: HTTPS service binding record (RFC 9460)
    HTTPS = 0x0041
class DomainNameServerQuery:
    """
    Minimalist DNS query/message to ask for A, AAAA and HTTPS records.
    Only meant for urllib3.future use. Does not cover all of possible extent of use.
    """

    def __init__(
        self, host: str, query_type: SupportedQueryType, override_id: int | None = None
    ) -> None:
        chosen_id = randint(0x0000, 0xFFFF) if override_id is None else override_id
        self._id = struct.pack("!H", chosen_id)
        self._host = host
        self._query = query_type
        # Standard query with the RD (recursion desired) bit set.
        self._flags = struct.pack("!H", 0x0100)
        # Exactly one question per message.
        self._qd_count = struct.pack("!H", 1)
        # Serialized form, computed lazily on the first __bytes__ call.
        self._cached: bytes | None = None

    @property
    def id(self) -> int:
        """Transaction identifier of this query, as an int."""
        unpacked: int = struct.unpack("!H", self._id)[0]
        return unpacked

    @property
    def raw_id(self) -> bytes:
        """Transaction identifier of this query, as two raw bytes."""
        return self._id

    def __repr__(self) -> str:
        return f"<Query '{self._host}' IN {self._query.name}>"

    def __bytes__(self) -> bytes:
        if self._cached:
            return self._cached
        # 12-byte header: ID, FLAGS, QDCOUNT, then zeroed AN/NS/AR counts.
        sections = [
            self._id,
            self._flags,
            self._qd_count,
            b"\x00\x00",
            b"\x00\x00",
            b"\x00\x00",
        ]
        # Question name: length-prefixed labels terminated by a zero octet.
        for label in self._host.split("."):
            sections.append(struct.pack("!B", len(label)))
            sections.append(label.encode("ascii"))
        sections.append(b"\x00")
        # QTYPE then QCLASS (always IN = 0x0001).
        sections.append(struct.pack("!H", self._query.value))
        sections.append(struct.pack("!H", 0x0001))
        self._cached = b"".join(sections)
        return self._cached

    @staticmethod
    def bulk(host: str, *types: SupportedQueryType) -> list[DomainNameServerQuery]:
        """Build one query per requested record type for *host*."""
        return [DomainNameServerQuery(host, query_type=t) for t in types]
#: Most common status code, not exhaustive at all.
#: Maps a DNS RCODE value to a human-readable label.
COMMON_RCODE_LABEL: dict[int, str] = {
    0: "No Error",
    1: "Format Error",
    2: "Server Failure",
    3: "Non-Existent Domain",
    5: "Query Refused",
    9: "Not Authorized",
}
#: Raised when a raw DNS response payload cannot be decoded.
class DomainNameServerParseException(Exception): ...
class DomainNameServerReturn:
    """
    Minimalist DNS response parser. Allow to quickly extract key-data out of it.
    Meant for A, AAAA and HTTPS records. Basically only what we need.
    """

    def __init__(self, payload: bytes) -> None:
        """Parse a raw DNS response message.

        :param payload: The complete DNS message as received from the wire.
        :raises DomainNameServerParseException: If the payload is truncated
            or otherwise not a well-formed DNS response.
        """
        try:
            # 12-byte header: ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT.
            up = struct.unpack("!HHHHHH", payload[:12])
            self._id = up[0]
            self._flags = up[1]
            self._qd_count = up[2]
            self._an_count = up[3]
            # RCODE is the low nibble of the second flags byte (clearer and
            # faster than the previous hex-string round-trip).
            self._rcode = payload[3] & 0x0F
            self._hostname: str = ""
            idx = 12
            # Question section: read the (uncompressed) QNAME labels.
            while True:
                c = payload[idx]
                if c == 0:
                    idx += 1
                    break
                self._hostname += payload[idx + 1 : idx + 1 + c].decode("ascii") + "."
                idx += c + 1
            self._records: list[tuple[SupportedQueryType, int, str | HttpsRecord]] = []
            if self._an_count:
                # Skip QTYPE + QCLASS of the question section.
                idx += 4
                while idx < len(payload):
                    # Each answer: NAME(ptr,2) TYPE(2) CLASS(2) TTL(4), then
                    # RDLENGTH(2) and RDLENGTH bytes of RDATA.
                    up = struct.unpack("!HHHI", payload[idx : idx + 10])
                    entry_size = struct.unpack("!H", payload[idx + 10 : idx + 12])[0]
                    data = payload[idx + 12 : idx + 12 + entry_size]
                    # Advance past the record unconditionally: the previous
                    # implementation hit `continue` on empty RDATA without
                    # moving `idx`, spinning forever.
                    idx += 12 + entry_size
                    if len(data) == 4:
                        decoded_data: str | HttpsRecord = inet4_ntoa(data)
                    elif len(data) == 16:
                        decoded_data = inet6_ntoa(data)
                    elif data:
                        decoded_data = parse_https_rdata(data)
                    else:
                        # Zero-length RDATA: nothing to decode, skip record.
                        continue
                    try:
                        self._records.append(
                            (SupportedQueryType(up[1]), up[-1], decoded_data)
                        )
                    except ValueError:
                        # Unsupported record type: silently ignored.
                        pass
        except (struct.error, IndexError, ValueError, UnicodeDecodeError) as e:
            raise DomainNameServerParseException(
                "A protocol error occurred while parsing the DNS response payload: "
                f"{str(e)}"
            ) from e

    @property
    def id(self) -> int:
        """Transaction identifier echoed from the query."""
        return self._id

    @property
    def hostname(self) -> str:
        """Question name, dot-terminated (e.g. ``"example.com."``)."""
        return self._hostname

    @property
    def records(self) -> list[tuple[SupportedQueryType, int, str | HttpsRecord]]:
        """Parsed answers as ``(record_type, ttl, decoded_value)`` tuples."""
        return self._records

    @property
    def is_found(self) -> bool:
        """Whether at least one supported record was decoded."""
        return bool(self._records)

    @property
    def rcode(self) -> int:
        """DNS response status code (0 means success)."""
        return self._rcode

    @property
    def is_ok(self) -> bool:
        """True when the server reported no error (RCODE == 0)."""
        return self._rcode == 0

    def __repr__(self) -> str:
        if self.is_ok:
            return f"<Records '{self.hostname}' {self._records}>"
        return f"<DNS Error '{self.hostname}' with Status {self.rcode} ({COMMON_RCODE_LABEL[self.rcode] if self.rcode in COMMON_RCODE_LABEL else 'Unknown'})>"

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
class SystemResolver(BaseResolver):
    """Resolver that delegates every lookup to the operating system."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The OS resolver has no explicit upstream server; discard those.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        # The system resolver can always reach the local machine.
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        if hostname == "localhost":
            return True
        return super().support(hostname)

    def recycle(self) -> BaseResolver:
        # Stateless: the same instance is always reusable.
        return self

    def close(self) -> None:
        # Nothing to tear down for the OS resolver.
        pass

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Delegate straight to :func:`socket.getaddrinfo`."""
        # the | tuple[int, bytes] is silently ignored, can't happen with our cases.
        return socket.getaddrinfo(  # type: ignore[return-value]
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )

View File

@@ -0,0 +1,322 @@
from __future__ import annotations
import base64
import binascii
import socket
import struct
import typing
if typing.TYPE_CHECKING:
class HttpsRecord(typing.TypedDict):
priority: int
target: str
alpn: list[str]
ipv4hint: list[str]
ipv6hint: list[str]
echconfig: list[str]
def inet4_ntoa(address: bytes) -> str:
    """
    Convert an IPv4 address from bytes to str.

    :raises ValueError: When *address* is not exactly 4 bytes.
    """
    if len(address) != 4:
        raise ValueError(
            f"IPv4 addresses are 4 bytes long, got {len(address)} byte(s) instead"
        )
    # Each byte becomes one decimal dotted-quad component.
    return ".".join(str(octet) for octet in address)
def inet6_ntoa(address: bytes) -> str:
    """
    Convert an IPv6 address from bytes to str.

    Produces the compressed textual form: the leftmost longest run of
    all-zero 16-bit groups (length >= 2) is collapsed to ``::`` and
    IPv4-mapped/compatible addresses get their dotted-quad tail.

    :raises ValueError: When *address* is not exactly 16 bytes.
    """
    if len(address) != 16:
        raise ValueError(
            f"IPv6 addresses are 16 bytes long, got {len(address)} byte(s) instead"
        )
    raw = binascii.hexlify(address)
    # Eight 16-bit groups, each with leading zeros dropped.
    groups = [raw[pos : pos + 4].decode().lstrip("0") or "0" for pos in range(0, 32, 4)]
    # Locate the leftmost longest run of "0" groups (ties go to the first).
    best_start = best_len = 0
    run_start = -1
    in_run = False
    for pos in range(8):
        if groups[pos] != "0":
            if in_run:
                if pos - run_start > best_len:
                    best_start, best_len = run_start, pos - run_start
                in_run = False
        elif not in_run:
            run_start = pos
            in_run = True
    if in_run and 8 - run_start > best_len:
        best_start, best_len = run_start, 8 - run_start
    # Compress only runs spanning at least two groups.
    if best_len > 1:
        if best_start == 0 and (
            best_len == 6 or (best_len == 5 and groups[5] == "ffff")
        ):
            # We have an embedded IPv4 address
            head = "::" if best_len == 6 else "::ffff:"
            return head + inet4_ntoa(address[12:])
        return (
            ":".join(groups[:best_start])
            + "::"
            + ":".join(groups[best_start + best_len :])
        )
    return ":".join(groups)
def packet_fragment(payload: bytes, *identifiers: bytes) -> tuple[bytes, ...]:
    """Split a buffer holding several concatenated DNS messages into pieces.

    Each expected message is located by its 2-byte transaction id given in
    *identifiers*; *payload* is then cut at every detected message start.

    :raises ValueError: When none of the identifiers can be found.

    NOTE(review): start positions are detected by searching for the raw id
    bytes, which can collide with unrelated payload bytes — confirm callers
    only use this on short, well-known response batches.
    """
    results = []
    offset = 0
    start_packet_idx = []
    lead_identifier = None
    # Find the first message: its id should appear near the start of the
    # buffer. A non-zero position is treated as a framing prefix (offset).
    for identifier in identifiers:
        idx = payload[:12].find(identifier)
        if idx == -1:
            continue
        if idx != 0:
            offset = idx
        start_packet_idx.append(idx - offset)
        lead_identifier = identifier
        break
    # Locate the remaining messages anywhere in the buffer. Without a framing
    # offset, anchor on the preceding length byte (b"\x02") to reduce false
    # positives.
    for identifier in identifiers:
        if identifier == lead_identifier:
            continue
        if offset == 0:
            idx = payload.find(b"\x02" + identifier)
        else:
            idx = payload.find(identifier)
        if idx == -1:
            continue
        start_packet_idx.append(idx - offset)
    if not start_packet_idx:
        raise ValueError(
            "no identifiable dns message emerged from given payload. "
            "this should not happen at all. networking issue?"
        )
    # Single message: the whole payload is the fragment.
    if len(start_packet_idx) == 1:
        return (payload,)
    # Cut the payload at each message start, in ascending order.
    start_packet_idx = sorted(start_packet_idx)
    previous_idx = None
    for idx in start_packet_idx:
        if previous_idx is None:
            previous_idx = idx
            continue
        results.append(payload[previous_idx:idx])
        previous_idx = idx
    results.append(payload[previous_idx:])
    return tuple(results)
def is_ipv4(addr: str) -> bool:
    """Return True when *addr* parses as an IPv4 address (per inet_aton)."""
    try:
        socket.inet_aton(addr)
    except OSError:
        return False
    return True
def is_ipv6(addr: str) -> bool:
    """Return True when *addr* parses as an IPv6 address (per inet_pton)."""
    try:
        socket.inet_pton(socket.AF_INET6, addr)
    except OSError:
        return False
    return True
def validate_length_of(hostname: str) -> None:
    """RFC 1035 impose a limit on a domain name length. We verify it there.

    :raises UnicodeError: When the stripped name exceeds 253 characters or
        any single label exceeds 63 characters.
    """
    if len(hostname.strip(".")) > 253:
        raise UnicodeError("hostname to resolve exceed 253 characters")
    # Feed any() lazily; no need to materialize an intermediate list.
    elif any(len(label) > 63 for label in hostname.split(".")):
        raise UnicodeError("at least one label to resolve exceed 63 characters")
def rfc1035_should_read(payload: bytes) -> bool:
    """Tell whether a length-prefixed DNS/TCP buffer is still incomplete.

    Each message carries a 2-byte big-endian length prefix (RFC 1035
    section 4.2.2). Returns True when the buffer ends mid-message, False
    when it is empty or contains only complete messages.
    """
    if not payload:
        return False
    # A bare (or partial) length prefix necessarily awaits more bytes.
    if len(payload) <= 2:
        return True
    remaining = payload
    while True:
        announced: int = struct.unpack("!H", remaining[:2])[0]
        body = remaining[2:]
        if len(body) == announced:
            return False
        if len(body) < announced:
            return True
        # More than one message in the buffer: examine the next frame.
        remaining = body[announced:]
def rfc1035_unpack(payload: bytes) -> tuple[bytes, ...]:
    """Split a buffer of length-prefixed (RFC 1035 §4.2.2) DNS messages."""
    messages = []
    remaining = payload
    while remaining:
        size: int = struct.unpack("!H", remaining[:2])[0]
        messages.append(remaining[2 : 2 + size])
        remaining = remaining[2 + size :]
    return tuple(messages)
def rfc1035_pack(message: bytes) -> bytes:
    """Frame *message* with its 2-byte big-endian length prefix (RFC 1035)."""
    prefix = struct.pack("!H", len(message))
    return prefix + message
def read_name(data: bytes, offset: int) -> tuple[str, int]:
    """Decode a DNS-encoded name starting at ``data[offset:]``.

    Handles RFC 1035 compression pointers. Returns the dotted name together
    with the offset of the first byte following the encoded name.
    """
    parts: list[str] = []
    cursor = offset
    while True:
        marker = data[cursor]
        if marker & 0xC0 == 0xC0:
            # Compression pointer: the low 14 bits address the name's tail.
            target = struct.unpack_from("!H", data, cursor)[0] & 0x3FFF
            tail, _ = read_name(data, target)
            parts.append(tail)
            cursor += 2
            break
        if marker == 0:
            # Root label terminates the name.
            cursor += 1
            break
        cursor += 1
        parts.append(data[cursor : cursor + marker].decode())
        cursor += marker
    return ".".join(parts), cursor
def parse_echconfigs(buf: bytes) -> list[str]:
    """
    buf is the raw bytes of the ECHConfig vector:
    - 2-byte total length, then for each:
    - 2-byte cfg length + that many bytes of cfg
    We return a list of Base64 strings (one per config).
    """
    if len(buf) < 2:
        return []
    vector_len = struct.unpack_from("!H", buf, 0)[0]
    limit = 2 + vector_len
    cursor = 2
    configs: list[str] = []
    # Walk each length-prefixed config until the announced vector end.
    while cursor + 2 <= limit:
        size = struct.unpack_from("!H", buf, cursor)[0]
        cursor += 2
        configs.append(base64.b64encode(buf[cursor : cursor + size]).decode())
        cursor += size
    return configs
def parse_https_rdata(rdata: bytes) -> HttpsRecord:
    """
    Parse the RDATA of an SVCB/HTTPS record.
    Returns a dict with keys: priority, target, alpn, ipv4hint, ipv6hint, echconfig.
    """
    off = 0
    # SvcPriority (2 bytes) then the TargetName as a DNS-encoded name.
    priority = struct.unpack_from("!H", rdata, off)[0]
    off += 2
    target, off = read_name(rdata, off)
    # pull out all the key/value params
    # Each SvcParam: 2-byte key, 2-byte length, then `length` bytes of value.
    params = {}
    while off + 4 <= len(rdata):
        key, length = struct.unpack_from("!HH", rdata, off)
        off += 4
        params[key] = rdata[off : off + length]
        off += length
    # decode ALPN (key=1), IPv4 (4), IPv6 (6), ECHConfig (5)
    def parse_alpn(buf: bytes) -> list[str]:
        # ALPN value: sequence of 1-byte-length-prefixed protocol ids.
        out = []
        i: int = 0
        while i < len(buf):
            ln = buf[i]
            out.append(buf[i + 1 : i + 1 + ln].decode())
            i += 1 + ln
        return out

    alpn: list[str] = parse_alpn(params.get(1, b""))
    # ipv4hint: concatenated 4-byte addresses.
    ipv4 = [
        inet4_ntoa(params[4][i : i + 4]) for i in range(0, len(params.get(4, b"")), 4)
    ]
    # ipv6hint: concatenated 16-byte addresses.
    ipv6 = [
        inet6_ntoa(params[6][i : i + 16]) for i in range(0, len(params.get(6, b"")), 16)
    ]
    echconfs = parse_echconfigs(params.get(5, b""))
    return {
        "priority": priority,
        "target": target or ".",  # empty name → root
        "alpn": alpn,
        "ipv4hint": ipv4,
        "ipv6hint": ipv6,
        "echconfig": echconfs,
    }
__all__ = (
"inet4_ntoa",
"inet6_ntoa",
"packet_fragment",
"is_ipv4",
"is_ipv6",
"validate_length_of",
"rfc1035_pack",
"rfc1035_unpack",
"rfc1035_should_read",
"parse_https_rdata",
)

View File

@@ -0,0 +1,453 @@
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install python-socks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...``)
- SOCKS4 (``proxy_url='socks4://...``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import annotations
import warnings
#: We purposely want to support PySocks[...] due to our shadowing of the legacy "urllib3". "Dot not disturb" policy.
BYPASS_SOCKS_LEGACY: bool = False
try:
from python_socks import (
ProxyConnectionError,
ProxyError,
ProxyTimeoutError,
ProxyType,
)
from python_socks.sync import Proxy
from ._socks_override import AsyncioProxy
except ImportError:
from ..exceptions import DependencyWarning
try:
import socks # noqa
except ImportError:
warnings.warn(
(
"SOCKS support in urllib3.future requires the installation of an optional "
"dependency: python-socks. For more information, see "
"https://urllib3future.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning,
)
else:
from ._socks_legacy import (
SOCKSConnection,
SOCKSHTTPConnectionPool,
SOCKSHTTPSConnection,
SOCKSHTTPSConnectionPool,
SOCKSProxyManager,
)
BYPASS_SOCKS_LEGACY = True
if not BYPASS_SOCKS_LEGACY:
raise
if not BYPASS_SOCKS_LEGACY:
import typing
from socket import socket
from socket import timeout as SocketTimeout
# asynchronous part
from .._async.connection import AsyncHTTPConnection, AsyncHTTPSConnection
from .._async.connectionpool import (
AsyncHTTPConnectionPool,
AsyncHTTPSConnectionPool,
)
from .._async.poolmanager import AsyncPoolManager
from .._typing import _TYPE_SOCKS_OPTIONS
from ..backend import HttpVersion
# synchronous part
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..contrib.ssa import AsyncSocket
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
try:
import ssl
except ImportError:
ssl = None # type: ignore[assignment]
class SOCKSConnection(HTTPConnection):  # type: ignore[no-redef]
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        """Store the SOCKS tunnel parameters, then defer to HTTPConnection."""
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    def _new_conn(self) -> socket:
        """
        Establish a new connection via the SOCKS proxy.

        The raw TCP socket is opened toward the proxy itself (through the
        configured resolver), then handed to python-socks which performs the
        SOCKS handshake toward ``self.host:self.port``.

        :raises ConnectTimeoutError: on connect/handshake timeout.
        :raises NewConnectionError: on any other proxy or OS-level failure.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # The SOCKS tunnel is TCP-only: drop any 4-tuple option explicitly
            # tagged "udp" and strip the protocol tag from the remaining ones.
            only_tcp_options = []

            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])

            extra_kw["socket_options"] = only_tcp_options

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = Proxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            _socket = self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                # fix: "socket_options" is only present in extra_kw when
                # self.socket_options was truthy; indexing with [] raised an
                # unexpected KeyError for connections configured without
                # socket options. .get() passes None in that case.
                socket_options=extra_kw.get("socket_options"),
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            # our dependency started to deprecate passing "_socket"
            # which is ... vital for our integration. We'll start by silencing the warning.
            # then we'll think on how to proceed.
            # A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
            # B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)

                return p.connect(
                    self.host,
                    self.port,
                    self.timeout,
                    _socket=_socket,
                )

        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e

        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):  # type: ignore[no-redef]
    """A TLS-secured HTTP connection established through a SOCKS proxy."""

    pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):  # type: ignore[no-redef]
    """Connection pool that hands out plain-text connections tunneled through SOCKS."""

    ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):  # type: ignore[no-redef]
    """Connection pool that hands out TLS connections tunneled through SOCKS."""

    ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):  # type: ignore[no-redef]
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        """Parse *proxy_url* and configure every pooled connection to tunnel
        through the resulting SOCKS endpoint.

        Credentials embedded in the URL ("user:pass@host") are only honored
        when neither *username* nor *password* was given explicitly.

        :raises ValueError: when the URL scheme is not a known SOCKS variant.
        """
        parsed = parse_url(proxy_url)

        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (SOCKS protocol version, resolve hostnames on the proxy?)
        version_by_scheme = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }

        if parsed.scheme not in version_by_scheme:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = version_by_scheme[parsed.scheme]

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        # HTTP/3 runs over UDP, which this TCP-based SOCKS tunnel cannot carry.
        if "disabled_svn" not in connection_pool_kw:
            connection_pool_kw["disabled_svn"] = set()

        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
class AsyncSOCKSConnection(AsyncHTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        """Store the SOCKS tunnel parameters, then defer to AsyncHTTPConnection."""
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    async def _new_conn(self) -> AsyncSocket:  # type: ignore[override]
        """
        Establish a new connection via the SOCKS proxy.

        Async twin of ``SOCKSConnection._new_conn``: open a socket toward the
        proxy through the resolver, then let python-socks run the SOCKS
        handshake toward ``self.host:self.port``.

        :raises ConnectTimeoutError: on connect/handshake timeout.
        :raises NewConnectionError: on any other proxy or OS-level failure.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # The SOCKS tunnel is TCP-only: drop any 4-tuple option explicitly
            # tagged "udp" and strip the protocol tag from the remaining ones.
            only_tcp_options = []

            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])

            extra_kw["socket_options"] = only_tcp_options

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = AsyncioProxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            _socket = await self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                # fix: "socket_options" is only present in extra_kw when
                # self.socket_options was truthy; indexing with [] raised an
                # unexpected KeyError for connections configured without
                # socket options. .get() passes None in that case.
                socket_options=extra_kw.get("socket_options"),
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            return await p.connect(
                self.host,
                self.port,
                self.timeout,
                _socket,
            )

        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e

        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class AsyncSOCKSHTTPSConnection(AsyncSOCKSConnection, AsyncHTTPSConnection):
    """An async TLS-secured HTTP connection established through a SOCKS proxy."""

    pass
class AsyncSOCKSHTTPConnectionPool(AsyncHTTPConnectionPool):
    """Async pool that hands out plain-text connections tunneled through SOCKS."""

    ConnectionCls = AsyncSOCKSConnection
class AsyncSOCKSHTTPSConnectionPool(AsyncHTTPSConnectionPool):
    """Async pool that hands out TLS connections tunneled through SOCKS."""

    ConnectionCls = AsyncSOCKSHTTPSConnection
class AsyncSOCKSProxyManager(AsyncPoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": AsyncSOCKSHTTPConnectionPool,
        "https": AsyncSOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        """Parse *proxy_url* and configure every pooled async connection to
        tunnel through the resulting SOCKS endpoint.

        Credentials embedded in the URL ("user:pass@host") are only honored
        when neither *username* nor *password* was given explicitly.

        :raises ValueError: when the URL scheme is not a known SOCKS variant.
        """
        parsed = parse_url(proxy_url)

        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (SOCKS protocol version, resolve hostnames on the proxy?)
        version_by_scheme = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }

        if parsed.scheme not in version_by_scheme:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = version_by_scheme[parsed.scheme]

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        # HTTP/3 runs over UDP, which this TCP-based SOCKS tunnel cannot carry.
        if "disabled_svn" not in connection_pool_kw:
            connection_pool_kw["disabled_svn"] = set()

        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = AsyncSOCKSProxyManager.pool_classes_by_scheme
# Public surface of this module. The synchronous SOCKS helpers are defined in
# both code paths (python-socks and the legacy PySocks bridge).
__all__ = [
    "SOCKSConnection",
    "SOCKSProxyManager",
    "SOCKSHTTPSConnection",
    "SOCKSHTTPSConnectionPool",
    "SOCKSHTTPConnectionPool",
]

# The async variants only exist when the modern python-socks backend was
# imported (i.e. the legacy PySocks bridge was not selected above).
if not BYPASS_SOCKS_LEGACY:
    __all__ += [
        "AsyncSOCKSConnection",
        "AsyncSOCKSHTTPSConnection",
        "AsyncSOCKSHTTPConnectionPool",
        "AsyncSOCKSHTTPSConnectionPool",
        "AsyncSOCKSProxyManager",
    ]

View File

@@ -0,0 +1,520 @@
from __future__ import annotations
import asyncio
import platform
import socket
import typing
import warnings
from ._timeout import timeout
from ._gro import open_dgram_connection, DatagramReader, DatagramWriter
StandardTimeoutError = socket.timeout
try:
from concurrent.futures import TimeoutError as FutureTimeoutError
except ImportError:
FutureTimeoutError = TimeoutError # type: ignore[misc]
try:
AsyncioTimeoutError = asyncio.exceptions.TimeoutError
except AttributeError:
AsyncioTimeoutError = TimeoutError # type: ignore[misc]
if typing.TYPE_CHECKING:
import ssl
from typing_extensions import Literal
from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
def _can_shutdown_and_close_selector_loop_bug() -> bool:
import platform
if platform.system() == "Windows" and platform.python_version_tuple()[:2] == (
"3",
"7",
):
return int(platform.python_version_tuple()[-1]) >= 17
return True
# Windows + asyncio bug where doing our shutdown procedure induces a crash
# in SelectorLoop
#   File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select
#     r, w, x = select.select(r, w, w, timeout)
#   [WinError 10038] An operation was attempted on something that is not a socket
# Evaluated once at import time; read by AsyncSocket.close().
_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = _can_shutdown_and_close_selector_loop_bug() is False
class AsyncSocket:
    """
    This class is brought to add a level of abstraction to an asyncio transport (reader, or writer)
    We don't want to have two distinct code (async/sync) but rather a unified and easily verifiable
    code base.
    'ssa' stands for Simplified - Socket - Asynchronous.
    """

    def __init__(
        self,
        family: socket.AddressFamily = socket.AF_INET,
        type: socket.SocketKind = socket.SOCK_STREAM,
        proto: int = -1,
        fileno: int | None = None,
    ) -> None:
        """Create the underlying non-blocking socket; no I/O happens here."""
        self.family: socket.AddressFamily = family
        self.type: socket.SocketKind = type
        self.proto: int = proto
        self._fileno: int | None = fileno

        self._connect_called: bool = False
        self._established: asyncio.Event = asyncio.Event()

        # we do that every time to forward properly options / advanced settings
        self._sock: socket.socket = socket.socket(
            family=self.family, type=self.type, proto=self.proto, fileno=fileno
        )
        # set nonblocking / or cause the loop to block with dgram socket...
        self._sock.settimeout(0)

        self._writer: asyncio.StreamWriter | DatagramWriter | None = None
        self._reader: asyncio.StreamReader | DatagramReader | None = None

        # serialize concurrent writers/readers on this socket
        self._writer_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._reader_semaphore: asyncio.Semaphore = asyncio.Semaphore()

        self._addr: tuple[str, int] | tuple[str, int, int, int] | None = None

        # timeout applied by connect()/recv(); None means "wait forever"
        self._external_timeout: float | int | None = None
        self._tls_in_tls = False

    def fileno(self) -> int:
        """Return the file descriptor this wrapper was built around."""
        return self._fileno if self._fileno is not None else self._sock.fileno()

    async def wait_for_close(self) -> None:
        """Best-effort wait for the asyncio writer to finish closing.

        No-op unless close() was already called (close() resets
        ``_connect_called``) and a writer exists."""
        if self._connect_called:
            return
        if self._writer is None:
            return
        try:
            # report made in https://github.com/jawah/niquests/issues/184
            # made us believe that sometime ssl_transport is freed before
            # getting there. So we could end up there with a half broken
            # writer state. The original user was using Windows at the time.
            is_ssl = self._writer.get_extra_info("ssl_object") is not None
        except AttributeError:
            is_ssl = False
        if is_ssl:
            # Give the connection a chance to write any data in the buffer,
            # and then forcibly tear down the SSL connection.
            await asyncio.sleep(0)
            self._writer.transport.abort()
        try:
            # wait_closed can hang indefinitely!
            # on Python 3.8 and 3.9
            # there's some case where Python want an explicit EOT
            # (spoiler: it was a CPython bug) fixed in recent interpreters.
            # to circumvent this and still have a proper close
            # we enforce a maximum delay (1000ms).
            async with timeout(1):
                await self._writer.wait_closed()
        except TimeoutError:
            pass

    def close(self) -> None:
        """Release the underlying fd, working around several event-loop quirks
        (uvloop missing shutdown()/close(), Windows SelectorEventLoop crash)."""
        if self._writer is not None:
            self._writer.close()

        edge_case_close_bug_exist = _CPYTHON_SELECTOR_CLOSE_BUG_EXIST

        # Windows + asyncio + asyncio.SelectorEventLoop limits us on how far
        # we can safely shutdown the socket.
        if not edge_case_close_bug_exist and platform.system() == "Windows":
            if hasattr(asyncio, "SelectorEventLoop") and isinstance(
                asyncio.get_running_loop(), asyncio.SelectorEventLoop
            ):
                edge_case_close_bug_exist = True

        try:
            # see https://github.com/MagicStack/uvloop/issues/241
            # and https://github.com/jawah/niquests/issues/166
            # probably not just uvloop.
            uvloop_edge_case_bug = False

            # keep track of our clean exit procedure
            shutdown_called = False
            close_called = False

            if hasattr(self._sock, "shutdown"):
                try:
                    self._sock.shutdown(socket.SHUT_RD)
                    shutdown_called = True
                except TypeError:
                    uvloop_edge_case_bug = True
                    # uvloop don't support shutdown! and sometime does not support close()...
                    # see https://github.com/jawah/niquests/issues/166 for ctx.
                    try:
                        self._sock.close()
                        close_called = True
                    except TypeError:
                        # last chance of releasing properly the underlying fd!
                        try:
                            direct_sock = socket.socket(fileno=self._sock.fileno())
                        except (OSError, ValueError):
                            pass
                        else:
                            try:
                                direct_sock.shutdown(socket.SHUT_RD)
                                shutdown_called = True
                            except OSError:
                                warnings.warn(
                                    (
                                        "urllib3-future is unable to properly close your async socket. "
                                        "This mean that you are probably using an asyncio implementation like uvloop "
                                        "that does not support shutdown() or/and close() on the socket transport. "
                                        "This will lead to unclosed socket (fd)."
                                    ),
                                    ResourceWarning,
                                )
                            finally:
                                direct_sock.detach()

            # we have to force call close() on our sock object (even after shutdown).
            # or we'll get a resource warning for sure!
            if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"):
                if not uvloop_edge_case_bug and not edge_case_close_bug_exist:
                    try:
                        self._sock.close()
                        close_called = True
                    except (OSError, TypeError):
                        pass

            if not close_called or not shutdown_called:
                # this branch detect whether we have an asyncio.TransportSocket instead of socket.socket.
                if hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                    try:
                        self._sock._sock.close()
                    except (AttributeError, OSError, TypeError):
                        pass
        except (
            OSError
        ):  # branch where we failed to connect and still try to release resource
            if isinstance(self._sock, socket.socket):
                try:
                    self._sock.close()  # don't call close on asyncio.TransportSocket
                except (OSError, TypeError, AttributeError):
                    pass
            elif hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                try:
                    self._sock._sock.detach()
                except (AttributeError, OSError, TypeError):
                    pass

        self._connect_called = False
        self._established.clear()

    async def wait_for_readiness(self) -> None:
        """Block until connect() (and any TLS upgrade) completed."""
        await self._established.wait()

    def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None:
        """Proxy straight to the underlying socket."""
        self._sock.setsockopt(__level, __optname, __value)

    @typing.overload
    def getsockopt(self, __level: int, __optname: int) -> int: ...

    @typing.overload
    def getsockopt(self, __level: int, __optname: int, buflen: int) -> bytes: ...

    def getsockopt(
        self, __level: int, __optname: int, buflen: int | None = None
    ) -> int | bytes:
        """Proxy straight to the underlying socket."""
        if buflen is None:
            return self._sock.getsockopt(__level, __optname)
        return self._sock.getsockopt(__level, __optname, buflen)

    def should_connect(self) -> bool:
        """True until connect() has been called (or after close())."""
        return self._connect_called is False

    async def connect(self, addr: tuple[str, int] | tuple[str, int, int, int]) -> None:
        """Connect the socket and attach asyncio reader/writer pairs.

        :raises OSError: when called twice without an intervening close().
        :raises socket.timeout: when ``settimeout`` was set and expired.
        :raises ConnectionError: on suspected fd allocation races in the loop.
        """
        if self._connect_called:
            raise OSError(
                "attempted to connect twice on a already established connection"
            )
        self._connect_called = True

        # there's a particularity on Windows
        # we must not forward non-IP in addr due to
        # a limitation in the network bridge used in asyncio
        if platform.system() == "Windows":
            from ..resolver.utils import is_ipv4, is_ipv6

            host, port = addr[:2]

            if not is_ipv4(host) and not is_ipv6(host):
                res = await asyncio.get_running_loop().getaddrinfo(
                    host,
                    port,
                    family=self.family,
                    type=self.type,
                )
                if not res:
                    raise socket.gaierror(f"unable to resolve hostname {host}")
                addr = res[0][-1]

        if self._external_timeout is not None:
            try:
                async with timeout(self._external_timeout):
                    await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                self._connect_called = False
                raise StandardTimeoutError from e
            except RuntimeError:
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )
        else:
            try:
                await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except RuntimeError:  # Defensive: CPython might raise RuntimeError if there is a FD allocation error.
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )

        if self.type == socket.SOCK_STREAM or self.type == -1:  # type: ignore[comparison-overlap]
            self._reader, self._writer = await asyncio.open_connection(sock=self._sock)
        elif self.type == socket.SOCK_DGRAM:
            self._reader, self._writer = await open_dgram_connection(sock=self._sock)

        # can become an asyncio.TransportSocket
        assert self._writer is not None
        self._sock = self._writer.get_extra_info("socket", self._sock)

        self._addr = addr
        self._established.set()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Upgrade the established stream to TLS in place and mutate this
        instance into a :class:`SSLAsyncSocket`."""
        await self._established.wait()
        self._established.clear()

        # only if Python <= 3.10
        try:
            setattr(
                asyncio.sslproto._SSLProtocolTransport,  # type: ignore[attr-defined]
                "_start_tls_compatible",
                True,
            )
        except AttributeError:
            pass

        if self.type == socket.SOCK_STREAM:
            assert self._writer is not None
            assert isinstance(self._writer, asyncio.StreamWriter)

            # below is hard to maintain. Starting with 3.11+, it is useless.
            protocol = self._writer._protocol  # type: ignore[attr-defined]

            await self._writer.drain()

            new_transport = await self._writer._loop.start_tls(  # type: ignore[attr-defined]
                self._writer._transport,  # type: ignore[attr-defined]
                protocol,
                ctx,
                server_side=False,
                server_hostname=server_hostname,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )

            self._writer._transport = new_transport  # type: ignore[attr-defined]

            transport = self._writer.transport
            protocol._stream_writer = self._writer
            protocol._transport = transport
            protocol._over_ssl = transport.get_extra_info("sslcontext") is not None

            self._tls_ctx = ctx
        else:
            raise RuntimeError("Unsupported socket type")

        self._established.set()

        self.__class__ = SSLAsyncSocket
        return self  # type: ignore[return-value]

    async def recv(self, size: int = -1) -> bytes | list[bytes]:
        """Receive data from the socket.

        Returns ``bytes`` for a single datagram (or stream chunk), or
        ``list[bytes]`` when GRO / batch-receive delivered multiple
        coalesced datagrams in one syscall. The caller can then feed
        all segments to the QUIC state-machine in a tight loop before
        probing, avoiding per-datagram overhead."""
        if size == -1:
            size = 65536

        assert self._reader is not None
        await self._established.wait()
        await self._reader_semaphore.acquire()

        try:
            if self._external_timeout is not None:
                try:
                    async with timeout(self._external_timeout):
                        return await self._reader.read(n=size)
                except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                    # fix: previously released the semaphore here AND in the
                    # enclosing ``finally`` below — asyncio.Semaphore has no
                    # upper bound, so the double release let two readers
                    # through concurrently afterward. The finally is enough.
                    raise StandardTimeoutError from e
                except OSError as e:  # Defensive: treat any OSError as ConnReset!
                    raise ConnectionResetError() from e

            return await self._reader.read(n=size)
        finally:
            self._reader_semaphore.release()

    async def read_exact(self, size: int = -1) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv(size=size)

    async def read(self) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv()

    async def sendall(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Write *data* and wait for the transport buffer to drain."""
        assert self._writer is not None
        await self._established.wait()
        await self._writer_semaphore.acquire()
        try:
            self._writer.write(data)  # type: ignore[arg-type]
            await self._writer.drain()
        except Exception:
            raise
        finally:
            self._writer_semaphore.release()

    async def write_all(
        self, data: bytes | bytearray | memoryview | list[bytes]
    ) -> None:
        """Just an alias for sendall(), it is needed due to our custom AsyncSocks override."""
        await self.sendall(data)

    async def send(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Alias for sendall()."""
        await self.sendall(data)

    def settimeout(self, __value: float | None = None) -> None:
        """Set the timeout applied by connect() and recv(); None disables it."""
        self._external_timeout = __value

    def gettimeout(self) -> float | None:
        """Return the timeout previously set with settimeout()."""
        return self._external_timeout

    def getpeername(self) -> tuple[str, int]:
        """Proxy straight to the underlying socket."""
        return self._sock.getpeername()  # type: ignore[no-any-return]

    def bind(self, addr: tuple[str, int]) -> None:
        """Proxy straight to the underlying socket."""
        self._sock.bind(addr)
class SSLAsyncSocket(AsyncSocket):
    """TLS flavor of :class:`AsyncSocket`, exposing the usual ssl.SSLSocket
    introspection helpers (peer certificate, ALPN, cipher, TLS version)."""

    # TLS context used for the handshake; set by AsyncSocket.wrap_socket().
    _tls_ctx: ssl.SSLContext
    # True when wrap_socket() was called on an already-TLS socket (TLS-in-TLS).
    _tls_in_tls: bool

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate — dict form, or DER bytes when
        *binary_form* is True."""
        return self.sslobj.getpeercert(binary_form=binary_form)  # type: ignore[return-value]

    def selected_alpn_protocol(self) -> str | None:
        """ALPN protocol negotiated during the handshake, if any."""
        return self.sslobj.selected_alpn_protocol()

    @property
    def sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        """The ssl object attached to the asyncio transport.

        :raises RuntimeError: when no writer exists or the transport does not
            expose an ``ssl_object`` extra.
        """
        if self._writer is not None:
            sslobj: ssl.SSLSocket | ssl.SSLObject | None = self._writer.get_extra_info(
                "ssl_object"
            )
            if sslobj is not None:
                return sslobj
        raise RuntimeError(
            '"ssl_object" could not be extracted from this SslAsyncSock instance'
        )

    def version(self) -> str | None:
        """Negotiated TLS protocol version string, if the handshake completed."""
        return self.sslobj.version()

    @property
    def context(self) -> ssl.SSLContext:
        """The SSLContext attached to the underlying ssl object."""
        return self.sslobj.context

    @property
    def _sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        # Alias kept for code paths that expect the private ``_sslobj`` name.
        return self.sslobj

    def cipher(self) -> tuple[str, str, int] | None:
        """(cipher name, protocol, secret bits) for the active connection."""
        return self.sslobj.cipher()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        # Wrapping an already-TLS socket means a TLS-in-TLS tunnel.
        self._tls_in_tls = True
        return await super().wrap_socket(
            ctx,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
def _has_complete_support_dgram() -> bool:
"""A bug exist in PyPy asyncio implementation that prevent us to use a DGRAM socket.
This piece of code inform us, potentially, if PyPy has fixed the winapi implementation.
See https://github.com/pypy/pypy/issues/4008 and https://github.com/jawah/niquests/pull/87
The stacktrace look as follows:
File "C:\\hostedtoolcache\\windows\\PyPy\3.10.13\x86\\Lib\asyncio\\windows_events.py", line 594, in connect
_overlapped.WSAConnect(conn.fileno(), address)
AttributeError: module '_overlapped' has no attribute 'WSAConnect'
"""
import platform
if platform.system() == "Windows" and platform.python_implementation() == "PyPy":
try:
import _overlapped # type: ignore[import-not-found]
except ImportError: # Defensive:
return False
if hasattr(_overlapped, "WSAConnect"):
return True
return False
return True
# Public names exported by this module.
__all__ = (
    "AsyncSocket",
    "SSLAsyncSocket",
    "_has_complete_support_dgram",
)

View File

@@ -0,0 +1,640 @@
"""
High-performance asyncio DatagramTransport with Linux-specific UDP
receive/send coalescing:
- GRO (receive): ``setsockopt(SOL_UDP, UDP_GRO)`` + ``recvmsg`` cmsg
- GSO (send): ``sendmsg`` with ``UDP_SEGMENT`` cmsg
All other platforms fall back to the standard asyncio DatagramTransport.
"""
from __future__ import annotations
import asyncio
import collections
import socket
import struct
from collections import deque
from typing import Any, Callable
from ..._constant import UDP_LINUX_GRO, UDP_LINUX_SEGMENT
# Native-endian uint16: wire format of the GRO/GSO segment-size cmsg payload.
_UINT16 = struct.Struct("=H")
# Receive buffer size when GRO may coalesce many datagrams into one read.
_DEFAULT_GRO_BUF = 65535

# Flow control watermarks for the custom write queue
_HIGH_WATERMARK = 64 * 1024
_LOW_WATERMARK = 16 * 1024

# GSO kernel limit: max segments per sendmsg call
_GSO_MAX_SEGMENTS = 64
def _sock_has_gro(sock: socket.socket) -> bool:
    """Check if GRO is enabled on *sock* (caller must have set it)."""
    try:
        enabled = sock.getsockopt(socket.SOL_UDP, UDP_LINUX_GRO)
    except OSError:
        # Kernel (or platform) without UDP_GRO support.
        return False
    return enabled == 1
def _sock_has_gso(sock: socket.socket) -> bool:
    """Check if the kernel supports GSO on *sock*."""
    try:
        # The value is irrelevant — only whether the option is queryable.
        sock.getsockopt(socket.SOL_UDP, UDP_LINUX_SEGMENT)
    except OSError:
        return False
    return True
def _split_gro_buffer(buf: bytes, segment_size: int) -> list[bytes]:
if segment_size <= 0 or len(buf) <= segment_size:
return [buf]
segments = []
mv = memoryview(buf)
for offset in range(0, len(buf), segment_size):
segments.append(bytes(mv[offset : offset + segment_size]))
return segments
def _group_by_segment_size(datagrams: list[bytes]) -> list[tuple[int, list[bytes]]]:
    """Group consecutive same-size datagrams for Linux UDP GSO.

    GSO requires all segments to be the same size (except the last,
    which may be shorter). Max 64 segments per ``sendmsg`` call."""
    if not datagrams:
        return []

    groups: list[tuple[int, list[bytes]]] = []
    run: list[bytes] = []
    run_size = -1

    for dgram in datagrams:
        # Extend the current run only while sizes match and the kernel's
        # per-call segment cap is not reached.
        if run and len(dgram) == run_size and len(run) < _GSO_MAX_SEGMENTS:
            run.append(dgram)
            continue
        if run:
            groups.append((run_size, run))
        run_size = len(dgram)
        run = [dgram]

    groups.append((run_size, run))
    return groups
def sync_recv_gro(
    sock: socket.socket, bufsize: int, gro_segment_size: int = 1280
) -> bytes | list[bytes]:
    """Blocking recvmsg with GRO cmsg parsing. Returns bytes or list[bytes]."""
    # Room for exactly one uint16 control message: the kernel reports the
    # GRO segment size through a SOL_UDP/UDP_GRO cmsg.
    ancbufsize = socket.CMSG_SPACE(_UINT16.size)

    data, ancdata, _flags, addr = sock.recvmsg(bufsize, ancbufsize)

    if not data:
        return b""

    # Fall back to the caller-provided segment size when no GRO cmsg arrived.
    segment_size = gro_segment_size
    for cmsg_level, cmsg_type, cmsg_data in ancdata:
        if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
            (segment_size,) = _UINT16.unpack(cmsg_data[:2])
            break

    # A payload no larger than one segment is a single, uncoalesced datagram.
    if len(data) <= segment_size:
        return data

    return _split_gro_buffer(data, segment_size)
def sync_sendmsg_gso(sock: socket.socket, datagrams: list[bytes]) -> None:
    """Batch-send datagrams using GSO. Falls back to individual sends."""
    for segment_size, group in _group_by_segment_size(datagrams):
        if len(group) == 1:
            # GSO only pays off for >1 segment; send a lone datagram plainly.
            sock.sendall(group[0])
            continue

        # One syscall: the kernel splits `buf` back into `segment_size`-sized
        # datagrams as instructed by the UDP_SEGMENT cmsg.
        buf = b"".join(group)
        sock.sendmsg(
            [buf],
            [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
        )
class _OptimizedDatagramTransport(asyncio.DatagramTransport):
__slots__ = (
"_loop",
"_sock",
"_protocol",
"_address",
"_gro_enabled",
"_gso_enabled",
"_gro_segment_size",
"_recv_buf_size",
"_closing",
"_closed_fut",
"_extra",
"_paused",
"_write_ready",
"_send_queue",
"_buffer_size",
"_protocol_paused",
)
def __init__(
self,
loop: asyncio.AbstractEventLoop,
sock: socket.socket,
protocol: asyncio.DatagramProtocol,
address: tuple[str, int] | None,
gro_enabled: bool,
gso_enabled: bool,
gro_segment_size: int,
) -> None:
super().__init__()
self._loop = loop
self._sock = sock
self._protocol = protocol
self._address = address
self._gro_enabled = gro_enabled
self._gso_enabled = gso_enabled
self._gro_segment_size = gro_segment_size
self._closing = False
self._closed_fut: asyncio.Future[None] = loop.create_future()
self._paused = False
self._write_ready = True
# Write buffer state
self._send_queue: deque[tuple[bytes, tuple[str, int] | None]] = (
collections.deque()
)
self._buffer_size = 0
self._protocol_paused = False
self._recv_buf_size = _DEFAULT_GRO_BUF if gro_enabled else gro_segment_size
self._extra = {
"peername": address,
"socket": sock,
"sockname": sock.getsockname(),
}
def get_extra_info(self, name: str, default: Any = None) -> Any:
return self._extra.get(name, default)
def is_closing(self) -> bool:
return self._closing
def close(self) -> None:
if self._closing:
return
self._closing = True
self._loop.remove_reader(self._sock.fileno())
# Drain the write queue gracefully in the background
if not self._send_queue:
self._loop.call_soon(self._call_connection_lost, None)
def abort(self) -> None:
self._closing = True
self._call_connection_lost(None)
def _call_connection_lost(self, exc: Exception | None) -> None:
try:
self._loop.remove_reader(self._sock.fileno())
self._loop.remove_writer(self._sock.fileno())
except Exception:
pass
try:
self._protocol.connection_lost(exc)
finally:
self._sock.close()
if not self._closed_fut.done():
self._closed_fut.set_result(None)
def sendto(self, data: bytes, addr: tuple[str, int] | None = None) -> None: # type: ignore[override]
if self._closing:
raise OSError("Transport is closing")
target = addr or self._address
if not self._write_ready:
self._queue_write(data, target)
return
try:
if target is not None:
self._sock.sendto(data, target)
else:
self._sock.send(data)
except BlockingIOError:
self._write_ready = False
self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
self._queue_write(data, target)
except OSError as exc:
self._protocol.error_received(exc)
def sendto_many(self, datagrams: list[bytes]) -> None:
"""Send multiple datagrams, using GSO when available.
Falls back to individual ``sendto`` calls when GSO is not
supported or the socket write buffer is full."""
if self._closing:
raise OSError("Transport is closing")
if not self._write_ready:
target = self._address
for dgram in datagrams:
self._queue_write(dgram, target)
return
if self._gso_enabled:
self._send_linux_gso(datagrams)
else:
for dgram in datagrams:
self.sendto(dgram)
def _send_linux_gso(self, datagrams: list[bytes]) -> None:
for segment_size, group in _group_by_segment_size(datagrams):
if len(group) == 1:
# Single datagram — plain send (GSO needs >1 segment)
try:
self._sock.send(group[0])
except BlockingIOError:
self._write_ready = False
self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
self._queue_write(group[0], self._address)
return
except OSError as exc:
self._protocol.error_received(exc)
continue
buf = b"".join(group)
try:
self._sock.sendmsg(
[buf],
[(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
)
except BlockingIOError:
self._write_ready = False
self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
# Queue individual datagrams as fallback
for dgram in group:
self._queue_write(dgram, self._address)
return
except OSError as exc:
self._protocol.error_received(exc)
def _queue_write(self, data: bytes, addr: tuple[str, int] | None) -> None:
self._send_queue.append((data, addr))
self._buffer_size += len(data)
self._maybe_pause_protocol()
def _maybe_pause_protocol(self) -> None:
if self._buffer_size >= _HIGH_WATERMARK and not self._protocol_paused:
self._protocol_paused = True
try:
self._protocol.pause_writing()
except AttributeError:
pass
def _maybe_resume_protocol(self) -> None:
    """Ask a paused protocol to resume writing once the buffer has
    drained down to the low watermark."""
    if not self._protocol_paused or self._buffer_size > _LOW_WATERMARK:
        return
    self._protocol_paused = False
    try:
        self._protocol.resume_writing()
    except AttributeError:
        # Protocol without flow-control support — nothing to notify.
        pass
def _on_write_ready(self) -> None:
    """Writer callback: flush the retry queue, then re-enable direct sends.

    A ``BlockingIOError`` leaves the writer callback armed and keeps the
    head entry queued; any other ``OSError`` is reported to the protocol
    and the offending datagram is dropped so the queue keeps draining.
    """
    queue = self._send_queue
    while queue:
        data, addr = queue[0]
        try:
            if addr is None:
                self._sock.send(data)
            else:
                self._sock.sendto(data, addr)
        except BlockingIOError:
            # Still not writable — retry on the next readiness event.
            return
        except OSError as exc:
            # Report, then fall through to drop this datagram.
            self._protocol.error_received(exc)
        queue.popleft()
        self._buffer_size -= len(data)
        self._maybe_resume_protocol()
    self._write_ready = True
    self._loop.remove_writer(self._sock.fileno())
    if self._closing:
        self._call_connection_lost(None)
def pause_reading(self) -> None:
    """Stop delivering incoming datagrams until ``resume_reading()``."""
    if self._paused:
        return
    self._paused = True
    self._loop.remove_reader(self._sock.fileno())
def resume_reading(self) -> None:
    """Re-register the socket for read events after ``pause_reading()``."""
    if not self._paused:
        return
    self._paused = False
    self._loop.add_reader(self._sock.fileno(), self._on_readable)
def _start(self) -> None:
    # Finish transport setup: announce the transport to the protocol on
    # the next loop iteration, then register for read events.
    self._loop.call_soon(self._protocol.connection_made, self)
    self._loop.add_reader(self._sock.fileno(), self._on_readable)
def _on_readable(self) -> None:
    # Reader callback: ignore late readiness events during teardown,
    # otherwise drain the socket through the GRO-aware receive loop.
    if self._closing:
        return
    self._recv_linux_gro()
def _recv_linux_gro(self) -> None:
    """Drain the socket via ``recvmsg``, splitting GRO-coalesced buffers.

    Reads until the socket would block. When the kernel attaches a UDP
    GRO cmsg, its uint16 payload is the segment size used to split the
    coalesced buffer back into individual datagrams.
    """
    # Ancillary buffer sized for a single uint16 cmsg (the segment size).
    ancbufsize = socket.CMSG_SPACE(_UINT16.size)
    while True:
        try:
            data, ancdata, _flags, addr = self._sock.recvmsg(
                self._recv_buf_size, ancbufsize
            )
        except BlockingIOError:
            # Socket drained for now; the reader callback will re-enter.
            return
        except OSError as exc:
            self._protocol.error_received(exc)
            return
        if not data:
            # NOTE(review): an empty buffer ends the drain loop, but a
            # zero-length UDP datagram is technically valid — confirm it
            # cannot occur here (or is acceptable to drop).
            return
        # Without a GRO cmsg, fall back to the configured segment size.
        segment_size = self._gro_segment_size
        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
                (segment_size,) = _UINT16.unpack(cmsg_data[:2])
                break
        if len(data) <= segment_size:
            # Single datagram — deliver as-is.
            self._protocol.datagram_received(data, addr)
        else:
            # Coalesced buffer — split and hand the whole batch over at once.
            segments = _split_gro_buffer(data, segment_size)
            self._protocol.datagrams_received(segments, addr)  # type: ignore[attr-defined]
async def create_udp_endpoint(
    loop: asyncio.AbstractEventLoop,
    protocol_factory: Callable[[], asyncio.DatagramProtocol],
    *,
    local_addr: tuple[str, int] | None = None,
    remote_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    reuse_port: bool = False,
    gro_segment_size: int = 1280,
    sock: socket.socket | None = None,
) -> tuple[asyncio.DatagramTransport, asyncio.DatagramProtocol]:
    """Create a UDP endpoint, preferring the GRO/GSO-optimized transport.

    Falls back to the stock ``loop.create_datagram_endpoint`` when the
    socket supports neither GRO nor GSO. A caller-provided ``sock`` is
    used as-is — it is assumed to already be non-blocking and, for GRO,
    to have the option enabled via ``setsockopt`` beforehand.
    """
    if sock is None:
        # Resolve the address family from whichever address is available.
        if family == socket.AF_UNSPEC:
            probe_addr = local_addr or remote_addr
            if probe_addr:
                infos = await loop.getaddrinfo(
                    probe_addr[0], probe_addr[1], type=socket.SOCK_DGRAM
                )
                family = infos[0][0]
            else:
                family = socket.AF_INET
        # Create, configure, bind and (optionally) connect the socket.
        sock = socket.socket(family, socket.SOCK_DGRAM)
        sock.setblocking(False)
        if reuse_port:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if local_addr:
            sock.bind(local_addr)
        connected_addr = None
        if remote_addr:
            await loop.sock_connect(sock, remote_addr)
            connected_addr = remote_addr
    else:
        # Caller provided a pre-connected socket — skip creation/bind/connect.
        try:
            connected_addr = sock.getpeername()
        except OSError:
            connected_addr = None

    # Probe capabilities — enabling GRO beforehand is the caller's job.
    gro_enabled = _sock_has_gro(sock)
    gso_enabled = _sock_has_gso(sock)

    if not (gro_enabled or gso_enabled):
        # Nothing to optimize: defer to the standard asyncio transport.
        return await loop.create_datagram_endpoint(
            lambda: protocol_factory(), sock=sock
        )

    # Wire the optimized transport around a single protocol instance.
    protocol = protocol_factory()
    transport = _OptimizedDatagramTransport(
        loop=loop,
        sock=sock,
        protocol=protocol,
        address=connected_addr,
        gro_enabled=gro_enabled,
        gso_enabled=gso_enabled,
        gro_segment_size=gro_segment_size,
    )
    transport._start()
    return transport, protocol
class DatagramReader:
    """Minimal duck-typed stand-in for ``asyncio.StreamReader``.

    ``AsyncSocket`` assigns an instance to ``self._reader`` so its
    existing ``recv()`` path keeps working unchanged.

    A batch of GRO-coalesced segments arriving in one syscall is stored
    as a single ``list[bytes]`` entry; ``read()`` hands that list back
    as-is so the caller can feed every segment to the QUIC state machine
    in one pass, avoiding a recv→feed→probe round-trip per datagram.
    """

    def __init__(self) -> None:
        # Entries are one datagram (bytes) or one GRO batch (list[bytes]).
        self._buffer: deque[bytes | list[bytes]] = collections.deque()
        self._waiter: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._eof = False

    def feed_datagram(self, data: bytes, addr: Any) -> None:
        """Feed a single (non-coalesced) datagram."""
        self._buffer.append(data)
        self._notify()

    def feed_datagrams(self, data: list[bytes], addr: Any) -> None:
        """Feed a batch of coalesced datagrams as one entry."""
        self._buffer.append(data)
        self._notify()

    def set_exception(self, exc: BaseException) -> None:
        """Record *exc*; it is raised by ``read()`` once the buffer drains."""
        self._exception = exc
        self._notify()

    def connection_lost(self, exc: BaseException | None) -> None:
        """Mark EOF, keeping *exc* (if any) to surface it to the reader."""
        self._eof = True
        if exc is not None:
            self._exception = exc
        self._notify()

    def _notify(self) -> None:
        # Wake a pending read(), if any.
        pending = self._waiter
        if pending is not None and not pending.done():
            pending.set_result(None)

    def _pop_ready(self) -> bytes | list[bytes] | None:
        # Buffered data wins over errors/EOF, mirroring StreamReader;
        # returns None when the caller has to wait.
        if self._buffer:
            return self._buffer.popleft()
        if self._exception is not None:
            raise self._exception
        if self._eof:
            return b""
        return None

    async def read(self, n: int = -1) -> bytes | list[bytes]:
        """Return the next buffered entry.

        * ``bytes`` — a single datagram (non-coalesced).
        * ``list[bytes]`` — one GRO batch of coalesced datagrams.
        * ``b""`` — EOF.
        """
        entry = self._pop_ready()
        if entry is not None:
            return entry
        waiter = asyncio.get_running_loop().create_future()
        self._waiter = waiter
        try:
            await waiter
        finally:
            self._waiter = None
        entry = self._pop_ready()
        return b"" if entry is None else entry
class DatagramWriter:
    """Minimal duck-typed stand-in for ``asyncio.StreamWriter``.

    ``AsyncSocket`` assigns an instance to ``self._writer`` so its
    existing ``sendall()``, ``close()`` and ``wait_for_close()`` paths
    keep working unchanged.
    """

    def __init__(
        self,
        transport: asyncio.DatagramTransport,
    ) -> None:
        self._transport = transport
        # Destination for connected sockets; None means unconnected.
        self._address: tuple[str, int] | None = transport.get_extra_info("peername")
        self._closed_event = asyncio.Event()
        self._paused = False
        self._drain_waiter: asyncio.Future[None] | None = None

    @property
    def transport(self) -> asyncio.DatagramTransport:
        return self._transport

    def write(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Send one datagram, or a batch when *data* is a list."""
        if self._transport.is_closing():
            # Silently drop during teardown, matching datagram semantics.
            return
        if not isinstance(data, list):
            self._transport.sendto(bytes(data), self._address)
            return
        batch_send = getattr(self._transport, "sendto_many", None)
        if batch_send is not None:
            batch_send(data)
        else:
            # Plain asyncio transport — send individually.
            for packet in data:
                self._transport.sendto(packet, self._address)

    async def drain(self) -> None:
        """Block until the transport lifts write flow-control, if engaged."""
        if not self._paused:
            return
        waiter = asyncio.get_running_loop().create_future()
        self._drain_waiter = waiter
        try:
            await waiter
        finally:
            self._drain_waiter = None

    def close(self) -> None:
        self._transport.close()

    async def wait_closed(self) -> None:
        # Set by the bridge protocol on connection_lost().
        await self._closed_event.wait()

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        return self._transport.get_extra_info(name, default)

    def _pause_writing(self) -> None:
        self._paused = True

    def _resume_writing(self) -> None:
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)
class _DatagramBridgeProtocol(asyncio.DatagramProtocol):
    """Forward ``asyncio.DatagramProtocol`` callbacks to the duck-typed
    ``DatagramReader`` / ``DatagramWriter`` pair."""

    def __init__(self, reader: DatagramReader) -> None:
        self._reader = reader
        self._writer: DatagramWriter | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # The transport is already wired through the DatagramWriter.
        pass

    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        self._reader.feed_datagram(data, addr)

    def datagrams_received(self, data: list[bytes], addr: tuple[str, int]) -> None:
        self._reader.feed_datagrams(data, addr)

    def error_received(self, exc: Exception) -> None:
        self._reader.set_exception(exc)

    def connection_lost(self, exc: BaseException | None) -> None:
        self._reader.connection_lost(exc)
        writer = self._writer
        if writer is not None:
            writer._closed_event.set()

    def pause_writing(self) -> None:
        writer = self._writer
        if writer is not None:
            writer._pause_writing()

    def resume_writing(self) -> None:
        writer = self._writer
        if writer is not None:
            writer._resume_writing()
async def open_dgram_connection(
    remote_addr: tuple[str, int] | None = None,
    *,
    local_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    sock: socket.socket | None = None,
    gro_segment_size: int = 1280,
) -> tuple[DatagramReader, DatagramWriter]:
    """Open a UDP endpoint and expose it as a stream-like reader/writer pair.

    Convenience wrapper around ``create_udp_endpoint`` that bridges the
    datagram transport into ``DatagramReader`` / ``DatagramWriter``.
    """
    running_loop = asyncio.get_running_loop()
    reader = DatagramReader()
    bridge = _DatagramBridgeProtocol(reader)
    transport, _unused_protocol = await create_udp_endpoint(
        running_loop,
        lambda: bridge,
        local_addr=local_addr,
        remote_addr=remote_addr,
        family=family,
        gro_segment_size=gro_segment_size,
        sock=sock,
    )
    writer = DatagramWriter(transport)
    # Late-bind the writer so connection_lost/flow-control reach it.
    bridge._writer = writer
    return reader, writer

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import enum
from asyncio import CancelledError, events, tasks
from types import TracebackType
# Public API; mirrors the surface of the stdlib ``asyncio.timeouts`` module.
__all__ = (
    "Timeout",
    "timeout",
)
class _State(enum.Enum):
    # Lifecycle of a Timeout context manager. The ``value`` strings are
    # user-visible: they appear in repr() and in error messages.
    CREATED = "created"  # instantiated, not yet entered
    ENTERED = "active"  # inside ``async with``, timer armed
    EXPIRING = "expiring"  # deadline hit, task cancellation requested
    EXPIRED = "expired"  # cancellation converted into TimeoutError
    EXITED = "finished"  # exited normally without expiring
class Timeout:
    """Asynchronous context manager for cancelling overdue coroutines.

    Use `timeout()` or `timeout_at()` rather than instantiating this class directly.
    """

    def __init__(self, when: float | None) -> None:
        """Schedule a timeout that will trigger at a given loop time.

        - If `when` is `None`, the timeout will never trigger.
        - If `when < loop.time()`, the timeout will trigger on the next
          iteration of the event loop.
        """
        self._state = _State.CREATED
        # Timer armed by reschedule(); a plain Handle when already overdue.
        self._timeout_handler: events.TimerHandle | events.Handle | None = None
        # The task to cancel when the deadline is reached (set on __aenter__).
        self._task: tasks.Task | None = None  # type: ignore[type-arg]
        self._when = when

    def when(self) -> float | None:
        """Return the current deadline."""
        return self._when

    def reschedule(self, when: float | None) -> None:
        """Reschedule the timeout."""
        # Only a currently-active timeout may be rescheduled.
        if self._state is not _State.ENTERED:
            if self._state is _State.CREATED:
                raise RuntimeError("Timeout has not been entered")
            raise RuntimeError(
                f"Cannot change state of {self._state.value} Timeout",
            )
        self._when = when
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
        if when is None:
            self._timeout_handler = None
        else:
            loop = events.get_running_loop()
            if when <= loop.time():
                # Already overdue: fire on the next loop iteration.
                self._timeout_handler = loop.call_soon(self._on_timeout)
            else:
                self._timeout_handler = loop.call_at(when, self._on_timeout)

    def expired(self) -> bool:
        """Is timeout expired during execution?"""
        return self._state in (_State.EXPIRING, _State.EXPIRED)

    def __repr__(self) -> str:
        # The leading "" yields a separating space only when extra info
        # follows, e.g. "<Timeout [active] when=10.0>".
        info = [""]
        if self._state is _State.ENTERED:
            when = round(self._when, 3) if self._when is not None else None
            info.append(f"when={when}")
        info_str = " ".join(info)
        return f"<Timeout [{self._state.value}]{info_str}>"

    async def __aenter__(self) -> "Timeout":
        if self._state is not _State.CREATED:
            raise RuntimeError("Timeout has already been entered")
        task = tasks.current_task()
        if task is None:
            raise RuntimeError("Timeout should be used inside a task")
        self._state = _State.ENTERED
        self._task = task
        # Arm the timer for the deadline given at construction time.
        self.reschedule(self._when)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        assert self._state in (_State.ENTERED, _State.EXPIRING)
        assert self._task is not None
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
            self._timeout_handler = None
        if self._state is _State.EXPIRING:
            self._state = _State.EXPIRED
            if exc_type is CancelledError:
                # Since there are no new cancel requests, we're
                # handling this.
                # NOTE(review): unlike CPython's asyncio.Timeout, this
                # backport never calls task.uncancel(), so on 3.11+ an
                # external cancellation arriving after expiry could be
                # absorbed here — confirm acceptable for supported targets.
                raise TimeoutError from exc_val
        elif self._state is _State.ENTERED:
            self._state = _State.EXITED
        return None

    def _on_timeout(self) -> None:
        assert self._state is _State.ENTERED
        assert self._task is not None
        # Cancel the owning task; __aexit__ converts the resulting
        # CancelledError into TimeoutError.
        self._task.cancel()
        self._state = _State.EXPIRING
        # drop the reference early
        self._timeout_handler = None
def timeout(delay: float | None) -> Timeout:
    """Return a ``Timeout`` async context manager.

    Useful when you want to apply timeout logic around a block of code,
    or when ``asyncio.wait_for`` is not suitable. For example:

    >>> async with asyncio.timeout(10):  # 10 seconds timeout
    ...     await long_running_task()

    ``delay`` is a value in seconds, or ``None`` to disable the timeout
    logic. ``long_running_task()`` is interrupted by raising
    ``asyncio.CancelledError``; the top-most affected ``timeout()``
    context manager converts that ``CancelledError`` into ``TimeoutError``.
    """
    # get_running_loop() is always called so that using timeout() outside
    # a running event loop raises, even when delay is None.
    now = events.get_running_loop().time()
    deadline = None if delay is None else now + delay
    return Timeout(deadline)

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from .protocol import ExtensionFromHTTP
from .raw import RawExtensionFromHTTP
from .sse import ServerSideEventExtensionFromHTTP
try:
from .ws import WebSocketExtensionFromHTTP, WebSocketExtensionFromMultiplexedHTTP
except ImportError:
WebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
WebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
from typing import TypeVar
T = TypeVar("T")
def recursive_subclasses(cls: type[T]) -> list[type[T]]:
    """Return every (direct and indirect) subclass of *cls*, depth-first.

    Order matches a pre-order walk of the inheritance tree: each direct
    subclass is followed immediately by its own subtree.
    """
    found: list[type[T]] = []
    pending: list[type[T]] = list(cls.__subclasses__())
    while pending:
        current = pending.pop(0)
        found.append(current)
        # Push the children right behind their parent to keep pre-order.
        pending = list(current.__subclasses__()) + pending
    return found
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[ExtensionFromHTTP]:
    """Resolve the extension class registered for *scheme*.

    With no scheme, the raw passthrough extension is returned. When
    *implementation* is given, only a plugin advertising that exact
    implementation name is accepted.

    :raises ImportError: when no loaded plugin supports the scheme.
    """
    if scheme is None:
        return RawExtensionFromHTTP

    wanted_scheme = scheme.lower()
    wanted_impl = implementation.lower() if implementation else implementation

    for candidate in recursive_subclasses(ExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if wanted_impl is not None and candidate.implementation() != wanted_impl:
            continue
        return candidate

    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
# Public module surface. The WebSocket entries resolve to ``None`` when the
# optional wsproto dependency is unavailable (see the guarded import above).
__all__ = (
    "ExtensionFromHTTP",
    "RawExtensionFromHTTP",
    "WebSocketExtensionFromHTTP",
    "WebSocketExtensionFromMultiplexedHTTP",
    "ServerSideEventExtensionFromHTTP",
    "load_extension",
)

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from .protocol import AsyncExtensionFromHTTP
from .raw import AsyncRawExtensionFromHTTP
from .sse import AsyncServerSideEventExtensionFromHTTP
try:
from .ws import (
AsyncWebSocketExtensionFromHTTP,
AsyncWebSocketExtensionFromMultiplexedHTTP,
)
except ImportError:
AsyncWebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
AsyncWebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
from .. import recursive_subclasses
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[AsyncExtensionFromHTTP]:
    """Resolve the async extension class registered for *scheme*.

    With no scheme, the raw passthrough extension is returned. When
    *implementation* is given, only a plugin advertising that exact
    implementation name is accepted.

    :raises ImportError: when no loaded plugin supports the scheme.
    """
    if scheme is None:
        return AsyncRawExtensionFromHTTP

    wanted_scheme = scheme.lower()
    wanted_impl = implementation.lower() if implementation else implementation

    for candidate in recursive_subclasses(AsyncExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if wanted_impl is not None and candidate.implementation() != wanted_impl:
            continue
        return candidate

    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
# Public module surface. The WebSocket entries resolve to ``None`` when the
# optional wsproto dependency is unavailable (see the guarded import above).
__all__ = (
    "AsyncExtensionFromHTTP",
    "AsyncRawExtensionFromHTTP",
    "AsyncWebSocketExtensionFromHTTP",
    "AsyncWebSocketExtensionFromMultiplexedHTTP",
    "AsyncServerSideEventExtensionFromHTTP",
    "load_extension",
)

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
import typing
from abc import ABCMeta
from contextlib import asynccontextmanager
from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from ....backend import HttpVersion
from ....backend._async._base import AsyncDirectStreamAccess
from ....util._async.traffic_police import AsyncTrafficPolice
from ....exceptions import (
BaseSSLError,
ProtocolError,
ReadTimeoutError,
SSLError,
MustRedialError,
)
class AsyncExtensionFromHTTP(metaclass=ABCMeta):
    """Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
    This will considerably ease downstream integration."""

    def __init__(self) -> None:
        # Direct (raw) stream access handle; set by start(), None once closed.
        self._dsa: AsyncDirectStreamAccess | None = None
        # The "101 Switching Protocols" response this extension started from.
        self._response: AsyncHTTPResponse | None = None
        # Pool traffic coordinator borrowed from the response.
        self._police_officer: AsyncTrafficPolice | None = None  # type: ignore[type-arg]

    @asynccontextmanager
    async def _read_error_catcher(self) -> typing.AsyncGenerator[None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # A read timeout is recoverable: keep the connection alive.
                clean_exit = True
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                clean_exit = True  # ws algorithms based on timeouts can expect this without being harmful!
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    await self.close()

    @asynccontextmanager
    async def _write_error_catcher(self) -> typing.AsyncGenerator[None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # NOTE(review): a send timeout surfaces as ReadTimeoutError
                # (urllib3 defines no write-timeout type) and, unlike the
                # read path, clean_exit stays False so the connection is
                # torn down — confirm both are intended.
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                raise SSLError(e) from e
            except OSError as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    await self.close()

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        """Extra keyword arguments the extension wants forwarded to urlopen()."""
        return {}

    async def start(self, response: AsyncHTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol."""
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise OSError("The HTTP extension is closed or uninitialized")
        # Take over the response's direct stream access and its pool guard.
        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response

    @property
    def closed(self) -> bool:
        """The extension counts as closed once stream access is released."""
        return self._dsa is None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        raise NotImplementedError

    @staticmethod
    def implementation() -> str:
        """Identifier of the concrete implementation (e.g. backing library)."""
        raise NotImplementedError

    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        raise NotImplementedError

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        raise NotImplementedError

    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError

    async def on_payload(
        self, callback: typing.Callable[[str | bytes | None], typing.Awaitable[None]]
    ) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from ....backend import HttpVersion
from .protocol import AsyncExtensionFromHTTP
class AsyncRawExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""

    @staticmethod
    def implementation() -> str:
        return "raw"

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Raw passthrough works over every protocol generation.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def supported_schemes() -> set[str]:
        # Raw mode is the fallback picked when no scheme is given.
        return set()

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return scheme

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """No specific request headers are needed before the 101 response."""
        return {}

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            await self._dsa.close()
            self._dsa = None
        if self._response is not None:
            await self._response.close()
            self._response = None
        self._police_officer = None

    async def next_payload(self) -> bytes | None:
        """Read the next chunk straight off the underlying stream."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            async with self._read_error_catcher():
                payload, _eot, _ = await self._dsa.recv_extended(None)
            return payload

    async def send_payload(self, buf: str | bytes) -> None:
        """Write *buf* straight to the underlying stream."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        payload = buf.encode() if isinstance(buf, str) else buf
        async with self._police_officer.borrow(self._response):
            async with self._write_error_catcher():
                await self._dsa.sendall(payload)

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from ....backend import HttpVersion
from ..sse import ServerSentEvent
from .protocol import AsyncExtensionFromHTTP
class AsyncServerSideEventExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Consume a ``text/event-stream`` response as Server-Sent Events (SSE)."""

    def __init__(self) -> None:
        super().__init__()
        # Last event id seen; re-attached to later events that omit "id:".
        self._last_event_id: str | None = None
        # Partial event text still awaiting its terminating blank line.
        self._buffer: str = ""
        # Decoded-content iterator over the response body.
        self._stream: typing.AsyncGenerator[bytes, None] | None = None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        return "native"

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # The body must stay streamable; never preload it.
        return {"preload_content": False}

    async def close(self) -> None:
        """Stop streaming and abort the underlying response when possible."""
        if self._stream is not None and self._response is not None:
            await self._stream.aclose()
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                async with self._police_officer.borrow(self._response):
                    await self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None

    @property
    def closed(self) -> bool:
        return self._stream is None

    async def start(self, response: AsyncHTTPResponse) -> None:
        """Attach to the response and open an unbounded decoded stream."""
        await super().start(response)
        self._stream = response.stream(-1, decode_content=True)

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        # Request the SSE media type and forbid caching, per the spec.
        return {"accept": "text/event-stream", "cache-control": "no-store"}

    @typing.overload
    async def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...

    @typing.overload
    async def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...

    async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote."""
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        try:
            raw_payload: str = (await self._stream.__anext__()).decode("utf-8")
        except StopAsyncIteration:
            # Stream exhausted; any partial event left in the buffer is dropped.
            await self._stream.aclose()
            self._stream = None
            return None
        if self._buffer:
            # Prepend the incomplete event kept from the previous chunk.
            raw_payload = self._buffer + raw_payload
            self._buffer = ""
        kwargs: dict[str, typing.Any] = {}
        eot = False
        for line in raw_payload.splitlines():
            if not line:
                # A blank line terminates the event.
                # NOTE(review): anything after this blank line in the same
                # chunk is discarded — a chunk carrying two complete events
                # would lose the second. Confirm the stream yields at most
                # one event per chunk.
                eot = True
                break
            key, _, value = line.partition(":")
            # Unknown fields and comment lines (empty key) are ignored.
            if key not in {"event", "data", "retry", "id"}:
                continue
            # A single leading space after the colon is not part of the value.
            if value.startswith(" "):
                value = value[1:]
            if key == "id":
                # Ids containing NUL must be ignored per the SSE spec.
                if "\u0000" in value:
                    continue
            if key == "retry":
                # retry must be an integer (milliseconds); ignore otherwise.
                try:
                    value = int(value)  # type: ignore[assignment]
                except (ValueError, TypeError):
                    continue
            kwargs[key] = value
        if eot is False:
            # Incomplete event: stash everything and pull the next chunk.
            self._buffer = raw_payload
            return await self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]
        # Carry the last seen event id into id-less events.
        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id
        event = ServerSentEvent(**kwargs)
        if event.id:
            self._last_event_id = event.id
        if raw is True:
            return raw_payload
        return event

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"sse", "psse"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # "sse" maps to https, "psse" (plain) to http.
        return {"sse": "https", "psse": "http"}[scheme]

View File

@@ -0,0 +1,238 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
AcceptConnection,
BytesMessage,
CloseConnection,
Ping,
Pong,
Request,
TextMessage,
)
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import ProtocolError as WebSocketProtocolError
from ....backend import HttpVersion
from ....exceptions import ProtocolError
from .protocol import AsyncExtensionFromHTTP
class AsyncWebSocketExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Drive a WebSocket session (via wsproto) over an upgraded HTTP/1.1 stream."""

    def __init__(self) -> None:
        super().__init__()
        # Client-side WebSocket state machine (handshake, framing, control frames).
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Handshake request headers, generated once by headers() and cached.
        self._request_headers: dict[str, str] | None = None
        # True once the remote sent CloseConnection — skip sending our own close.
        self._remote_shutdown: bool = False

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11}

    @staticmethod
    def implementation() -> str:
        return "wsproto"

    async def start(self, response: AsyncHTTPResponse) -> None:
        """Complete the WebSocket handshake from the 101 response headers.

        wsproto only consumes raw HTTP/1.1 bytes, so a minimal 101
        response is reconstructed from the received headers and fed into
        the state machine.
        """
        await super().start(response)
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
        fake_http_response += b"Sec-Websocket-Accept: "
        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )
        fake_http_response += accept_token.encode() + b"\r\n"
        # Forward any negotiated extensions (e.g. permessage-deflate).
        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )
        fake_http_response += b"\r\n"
        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!
        event = next(self._protocol.events())
        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers
        try:
            # Let wsproto produce a full handshake request; only its headers
            # are kept (host/target below are placeholders — the real request
            # supplies its own).
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!
        # Skip the request line and Host header (indices 0-1) plus the two
        # empty strings produced by the trailing CRLF CRLF.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}
        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v
        if http_version != HttpVersion.h11:
            # RFC 8441 extended CONNECT: Upgrade/Connection are h1-only;
            # use the pseudo-headers instead.
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"
        self._request_headers = request_headers
        return request_headers

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                async with self._police_officer.borrow(self._response):
                    # Only send our close frame if the remote did not already
                    # shut the session down.
                    if self._remote_shutdown is False:
                        # NOTE(review): close code 0 is not a valid RFC 6455
                        # code; wsproto may reject it (the error is swallowed
                        # below) — confirm this is intended.
                        try:
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            pass
                        else:
                            async with self._write_error_catcher():
                                await self._dsa.sendall(data_to_send)
                    await self._dsa.close()
                    self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                self._police_officer.forget(self._response)
            else:
                await self._response.close()
            self._response = None
        self._police_officer = None

    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            # First drain events already decoded from earlier reads.
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    await self.close()
                    return None
                elif isinstance(event, Ping):
                    # Answer pings transparently, then keep waiting for data.
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        await self.close()
                        raise ProtocolError from e
                    async with self._write_error_catcher():
                        await self._dsa.sendall(data_to_send)
            # Nothing pending: read from the socket until a full message,
            # close, or protocol error arrives.
            while True:
                async with self._read_error_catcher():
                    data, eot, _ = await self._dsa.recv_extended(None)
                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    raise ProtocolError from e
                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        await self.close()
                        return None
                    elif isinstance(event, Ping):
                        data_to_send = self._protocol.send(event.response())
                        async with self._write_error_catcher():
                            await self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        # Pong answers carry no payload for the caller.
                        continue

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            try:
                # str payloads become text frames; bytes become binary frames.
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                raise ProtocolError from e
            async with self._write_error_catcher():
                await self._dsa.sendall(data_to_send)

    async def ping(self) -> None:
        """Send an unsolicited Ping control frame to the remote peer."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            try:
                data_to_send: bytes = self._protocol.send(Ping())
            except WebSocketProtocolError as e:
                raise ProtocolError from e
            async with self._write_error_catcher():
                await self._dsa.sendall(data_to_send)

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"ws": "http", "wss": "https"}[scheme]
class AsyncWebSocketExtensionFromMultiplexedHTTP(AsyncWebSocketExtensionFromHTTP):
    """WebSocket bootstrapped over a multiplexed HTTP connection (h2/h3).

    Implements the extended CONNECT mechanism of RFC 8441. Beware that few
    public servers actually deploy it.
    """

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Usable over HTTP/1.1, HTTP/2 and HTTP/3 alike."""
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        """Identifier of the negotiation mechanism in use."""
        return "rfc8441"

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import typing
from abc import ABCMeta
from contextlib import contextmanager
from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING:
from ...backend import HttpVersion
from ...backend._base import DirectStreamAccess
from ...response import HTTPResponse
from ...util.traffic_police import TrafficPolice
from ...exceptions import (
BaseSSLError,
ProtocolError,
ReadTimeoutError,
SSLError,
MustRedialError,
)
class ExtensionFromHTTP(metaclass=ABCMeta):
    """Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.

    Concrete subclasses (raw tunneling, WebSocket, SSE, ...) implement the
    abstract hooks below; this base class only carries the shared plumbing:
    direct stream access, the originating response, and the pool coordinator.
    This will considerably ease downstream integration."""
    def __init__(self) -> None:
        # Raw, direct access to the underlying stream; set by start(), cleared on close().
        self._dsa: DirectStreamAccess | None = None
        # The "101 Switching Protocols" response this extension was started from.
        self._response: HTTPResponse | None = None
        # Connection-pool coordinator borrowed from the response; serializes stream use.
        self._police_officer: TrafficPolice | None = None  # type: ignore[type-arg]
    @contextmanager
    def _read_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # A read timeout is considered recoverable (caller may retry),
                # so clean_exit is set True and the connection is kept alive.
                clean_exit = True
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                clean_exit = True  # ws algorithms based on timeouts can expect this without being harmful!
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()
    @contextmanager
    def _write_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # NOTE(review): unlike the read catcher, clean_exit stays False
                # here, so a write timeout tears down the connection. Also note
                # the error type raised is ReadTimeoutError even though this is
                # the write path — confirm this asymmetry is intentional.
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                raise SSLError(e) from e
            except OSError as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()
    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        """Return prerequisites. Must be passed as additional parameters to urlopen."""
        return {}
    def start(self, response: HTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol.

        Grabs direct stream access, the pool coordinator, and the response
        itself from the freshly received "101 Switching Protocols" response.

        :raises RuntimeError: If the response's file-like object does not
            expose direct I/O access to the underlying stream.
        """
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise RuntimeError(
                "Attempt to start an HTTP extension without direct I/O access to the stream"
            )
        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response
    @property
    def closed(self) -> bool:
        # The stream handle is cleared on close(), so it doubles as a liveness flag.
        return self._dsa is None
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        raise NotImplementedError
    @staticmethod
    def implementation() -> str:
        """Identifier of the mechanism/library backing this extension."""
        raise NotImplementedError
    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        raise NotImplementedError
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        raise NotImplementedError
    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError
    def on_payload(self, callback: typing.Callable[[str | bytes | None], None]) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from ...backend import HttpVersion
from .protocol import ExtensionFromHTTP
class RawExtensionFromHTTP(ExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status.

    No framing, no handshake: bytes in, bytes out. Serves as the fallback
    extension when no specific sub-protocol applies.
    """
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Raw tunneling works over any negotiated HTTP version.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        # A raw upgrade needs no extra handshake headers.
        return {}
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            with self._write_error_catcher():
                self._dsa.close()
            self._dsa = None
        if self._response is not None:
            self._response.close()
            self._response = None
        self._police_officer = None
    @staticmethod
    def implementation() -> str:
        return "raw"
    @staticmethod
    def supported_schemes() -> set[str]:
        # Not bound to any URL scheme: this extension is the generic fallback.
        return set()
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Identity mapping: whatever scheme was given is kept as-is.
        return scheme
    def next_payload(self) -> bytes | None:
        """Read the next chunk of bytes from the tunneled stream (blocking)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            with self._read_error_catcher():
                data, eot, _ = self._dsa.recv_extended(None)
                return data
    def send_payload(self, buf: str | bytes) -> None:
        """Write a buffer to the tunneled stream; ``str`` is UTF-8 encoded first."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        if isinstance(buf, str):
            buf = buf.encode()
        with self._police_officer.borrow(self._response):
            with self._write_error_catcher():
                self._dsa.sendall(buf)

View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import json
import typing
from threading import RLock
if typing.TYPE_CHECKING:
from ...response import HTTPResponse
from ...backend import HttpVersion
from .protocol import ExtensionFromHTTP
class ServerSentEvent:
    """Container for a single Server-Sent Event record (event/data/id/retry)."""

    def __init__(
        self,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        # Normalize constructor arguments the way browsers do: the default
        # event type is "message"; data and id default to the empty string.
        self._event = event or "message"
        self._data = "" if data is None else data
        self._id = "" if id is None else id
        self._retry = retry

    @property
    def event(self) -> str:
        """Event type; "message" when the server did not specify one."""
        return self._event

    @property
    def data(self) -> str:
        """Raw payload carried by the event (possibly empty)."""
        return self._data

    @property
    def id(self) -> str:
        """Event id; empty string when absent."""
        return self._id

    @property
    def retry(self) -> int | None:
        """Reconnection delay hint, or None when the server sent none."""
        return self._retry

    def json(self) -> typing.Any:
        """Decode the data field as a JSON document."""
        return json.loads(self._data)

    def __repr__(self) -> str:
        attrs = [f"event={self._event!r}"]
        if self._data:
            attrs.append(f"data={self._data!r}")
        if self._id:
            attrs.append(f"id={self._id!r}")
        if self._retry is not None:
            attrs.append(f"retry={self._retry!r}")
        return f"ServerSentEvent({', '.join(attrs)})"
class ServerSideEventExtensionFromHTTP(ExtensionFromHTTP):
    """Consume a ``text/event-stream`` response as :class:`ServerSentEvent` objects.

    One-way channel: the server pushes events, :meth:`send_payload` is forbidden.
    """
    def __init__(self) -> None:
        super().__init__()
        # Most recent "id:" field seen; re-attached to events that omit one.
        self._last_event_id: str | None = None
        # Partially received event (no blank-line terminator seen yet).
        self._buffer: str = ""
        # Serializes next_payload() calls across threads.
        self._lock = RLock()
        # Lazy body iterator over the underlying HTTP response; None once closed.
        self._stream: typing.Generator[bytes, None, None] | None = None
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
    @staticmethod
    def implementation() -> str:
        return "native"
    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # SSE bodies are unbounded: the response must never be pre-loaded.
        return {"preload_content": False}
    @property
    def closed(self) -> bool:
        return self._stream is None
    def close(self) -> None:
        if self._stream is not None and self._response is not None:
            self._stream.close()
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                with self._police_officer.borrow(self._response):
                    # Abort the low-level stream: an SSE body has no natural
                    # end, so draining it gracefully is not an option.
                    self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None
    def start(self, response: HTTPResponse) -> None:
        super().start(response)
        # amt=-1: presumably "yield whatever is available per read" — confirm
        # against HTTPResponse.stream semantics.
        self._stream = response.stream(-1, decode_content=True)
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        # Advertise SSE support and disable caching, per the SSE conventions.
        return {"accept": "text/event-stream", "cache-control": "no-store"}
    @typing.overload
    def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...
    @typing.overload
    def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...
    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote.

        :param raw: When True, return the unparsed text chunk instead of a
            :class:`ServerSentEvent`.
        :returns: The next event (or raw text), or None once the server ended
            the stream.
        :raises OSError: If the extension is closed or was never started.
        """
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._lock:
            try:
                raw_payload: str = next(self._stream).decode("utf-8")
            except StopIteration:
                # Server ended the stream: report end-of-pipeline with None.
                self._stream = None
                return None
            if self._buffer:
                # Prepend the incomplete event left over from the previous read.
                raw_payload = self._buffer + raw_payload
                self._buffer = ""
            kwargs: dict[str, typing.Any] = {}
            eot = False
            for line in raw_payload.splitlines():
                if not line:
                    # A blank line terminates the event.
                    eot = True
                    break
                key, _, value = line.partition(":")
                if key not in {"event", "data", "retry", "id"}:
                    # Unknown fields (and ":" comment lines, whose key is "") are skipped.
                    continue
                if value.startswith(" "):
                    # A single leading space after the colon is not part of the value.
                    value = value[1:]
                if key == "id":
                    if "\u0000" in value:
                        # An id containing NUL is ignored.
                        continue
                if key == "retry":
                    try:
                        value = int(value)  # type: ignore[assignment]
                    except (ValueError, TypeError):
                        continue
                # NOTE(review): repeated fields overwrite rather than
                # accumulate, so a multi-line "data:" event keeps only its last
                # line, and any second event inside the same chunk is dropped —
                # confirm against the WHATWG SSE specification whether this is
                # intended.
                kwargs[key] = value
            if eot is False:
                # No terminator yet: stash the partial event and recurse to
                # pull the next chunk from the stream.
                self._buffer = raw_payload
                return self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]
            if "id" not in kwargs and self._last_event_id is not None:
                # Events without an id inherit the last one seen, per spec.
                kwargs["id"] = self._last_event_id
            event = ServerSentEvent(**kwargs)
            if event.id:
                self._last_event_id = event.id
            if raw is True:
                return raw_payload
            return event
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
    @staticmethod
    def supported_schemes() -> set[str]:
        # "sse" maps to https, "psse" (plain) to http — see scheme_to_http_scheme.
        return {"sse", "psse"}
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"sse": "https", "psse": "http"}[scheme]

View File

@@ -0,0 +1,247 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...response import HTTPResponse
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
AcceptConnection,
BytesMessage,
CloseConnection,
Ping,
Pong,
Request,
TextMessage,
)
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import ProtocolError as WebSocketProtocolError
from ...backend import HttpVersion
from ...exceptions import ProtocolError
from .protocol import ExtensionFromHTTP
class WebSocketExtensionFromHTTP(ExtensionFromHTTP):
    """WebSocket (RFC 6455) negotiated via an HTTP/1.1 "101 Switching Protocols" response.

    The wsproto state-machine performs all frame (de)serialization; the parent
    class provides direct access to the underlying stream.
    """
    def __init__(self) -> None:
        super().__init__()
        # Client-side WebSocket state-machine (frame encoder/decoder).
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Cache for headers(): generated once per extension instance.
        self._request_headers: dict[str, str] | None = None
        # True once the remote sent CloseConnection; suppresses our own close frame.
        self._remote_shutdown: bool = False
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Plain RFC 6455 upgrades only exist over HTTP/1.1; see the
        # Multiplexed subclass for h2/h3 support.
        return {HttpVersion.h11}
    @staticmethod
    def implementation() -> str:
        return "wsproto"
    def start(self, response: HTTPResponse) -> None:
        """Replay the already-parsed 101 response into wsproto as a synthetic handshake.

        :raises ProtocolError: If the mandatory Sec-Websocket-Accept header is missing.
        :raises RuntimeError: If wsproto does not acknowledge the handshake.
        """
        super().start(response)
        # wsproto expects to parse the raw HTTP/1.1 response itself, but
        # urllib3 already consumed it; rebuild an equivalent byte-stream
        # from the response headers.
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
        fake_http_response += b"Sec-Websocket-Accept: "
        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )
        fake_http_response += accept_token.encode() + b"\r\n"
        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )
        fake_http_response += b"\r\n"
        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen
        event = next(self._protocol.events())
        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers
        try:
            # Have wsproto serialize a complete handshake request, then scrape
            # the header lines out of it (host/target are throwaway placeholders).
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen
        # Skip the request line and the placeholder Host header; drop the
        # trailing empty segments left by the final CRLFCRLF.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}
        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v
        if http_version != HttpVersion.h11:
            # RFC 8441: over h2/h3 the upgrade is expressed with the extended
            # CONNECT method rather than Upgrade/Connection headers.
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"
        self._request_headers = request_headers
        return request_headers
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                with self._police_officer.borrow(self._response):
                    # Only emit our own close frame if the remote did not
                    # already initiate the closing handshake.
                    if self._remote_shutdown is False:
                        try:
                            # NOTE(review): close code 0 is outside the range
                            # registered by RFC 6455 — confirm that 1000
                            # (normal closure) was not intended here.
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            pass
                        else:
                            with self._write_error_catcher():
                                self._dsa.sendall(data_to_send)
                    self._dsa.close()
                self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                self._police_officer.forget(self._response)
            else:
                self._response.close()
            self._response = None
        self._police_officer = None
    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote.

        Returns a ``str`` for text frames, ``bytes`` for binary frames, and
        None once the remote closed the pipeline. Ping frames are answered
        transparently; Pong frames are discarded.

        :raises OSError: If the extension is closed or was never started.
        :raises ProtocolError: On a WebSocket protocol violation.
        """
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            # we may have pending event to unpack!
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    self.close()
                    return None
                elif isinstance(event, Ping):
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        self.close()
                        raise ProtocolError from e
                    with self._write_error_catcher():
                        self._dsa.sendall(data_to_send)
            # No buffered event: block on the socket until a full message arrives.
            while True:
                with self._read_error_catcher():
                    data, eot, _ = self._dsa.recv_extended(None)
                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e
                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        self.close()
                        return None
                    elif isinstance(event, Ping):
                        # Answer keep-alive probes without surfacing them.
                        try:
                            data_to_send = self._protocol.send(event.response())
                        except WebSocketProtocolError as e:
                            self.close()
                            raise ProtocolError from e
                        with self._write_error_catcher():
                            self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        continue
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote.

        ``str`` is sent as a text frame, ``bytes`` as a binary frame.

        :raises OSError: If the extension is closed or was never started.
        :raises ProtocolError: If the WebSocket state-machine refuses the frame.
        """
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            try:
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                self.close()
                raise ProtocolError from e
            with self._write_error_catcher():
                self._dsa.sendall(data_to_send)
    def ping(self) -> None:
        """Send an unsolicited Ping frame; no-op after remote shutdown.

        :raises OSError: If the extension is closed or was never started.
        :raises ProtocolError: If the WebSocket state-machine refuses the frame.
        """
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            if self._remote_shutdown is False:
                try:
                    data_to_send: bytes = self._protocol.send(Ping())
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e
                with self._write_error_catcher():
                    self._dsa.sendall(data_to_send)
    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"ws": "http", "wss": "https"}[scheme]
class WebSocketExtensionFromMultiplexedHTTP(WebSocketExtensionFromHTTP):
    """WebSocket bootstrapped over a multiplexed HTTP connection (h2/h3).

    Implements RFC 8441 (extended CONNECT); over HTTP/3 this is also known
    as RFC 9220. Beware that few public servers actually deploy it.
    """

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Usable over HTTP/1.1, HTTP/2 and HTTP/3 alike."""
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        """Identifier of the negotiation mechanism in use."""
        return "rfc8441"  # also known as rfc9220 (http3)

View File

@@ -0,0 +1,373 @@
from __future__ import annotations
import socket
import typing
from email.errors import MessageDefect
if typing.TYPE_CHECKING:
from ._async.connection import AsyncHTTPConnection
from ._async.connectionpool import AsyncConnectionPool
from ._async.response import AsyncHTTPResponse
from ._typing import _TYPE_REDUCE_RESULT
from .backend import ResponsePromise
from .connection import HTTPConnection
from .connectionpool import ConnectionPool
from .response import HTTPResponse
from .util.retry import Retry
# Base Exceptions
# Probe for the ssl module at import time: on CPython builds compiled without
# OpenSSL, importing ssl fails; substitute a stub exception type so that
# `except BaseSSLError` clauses elsewhere remain valid.
try: # Compiled with SSL?
    import ssl
    BaseSSLError = ssl.SSLError
except (ImportError, AttributeError):
    ssl = None # type: ignore[assignment]
    class BaseSSLError(BaseException): # type: ignore[no-redef]
        # Stub: never raised by this package; exists only so except-clauses run.
        pass
class HTTPError(Exception):
    """Base exception used by this module."""
class HTTPWarning(Warning):
    """Base warning used by this module."""
class PoolError(HTTPError):
    """Base exception for errors caused within a pool."""
    def __init__(
        self, pool: ConnectionPool | AsyncConnectionPool, message: str
    ) -> None:
        self.pool = pool
        super().__init__(f"{pool}: {message}")
    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. The pool reference is deliberately dropped:
        # pools hold live sockets and cannot be pickled.
        return self.__class__, (None, None)
class RequestError(PoolError):
    """Base exception for PoolErrors that have associated URLs."""
    def __init__(
        self, pool: ConnectionPool | AsyncConnectionPool, url: str, message: str
    ) -> None:
        self.url = url
        super().__init__(pool, message)
    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. The URL survives pickling; the pool does not.
        return self.__class__, (None, self.url, None)
class SSLError(HTTPError):
    """Raised when SSL certificate fails in an HTTPS connection."""
class ProxyError(HTTPError):
    """Raised when the connection to a proxy fails."""
    # The original error is also available as __cause__.
    original_error: Exception
    def __init__(self, message: str, error: Exception) -> None:
        super().__init__(message, error)
        self.original_error = error
class DecodeError(HTTPError):
    """Raised when automatic decoding based on Content-Type fails."""
class ProtocolError(HTTPError):
    """Raised when something unexpected happens mid-request/response."""
#: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError
# Leaf Exceptions
class MaxRetryError(RequestError):
    """Raised when the maximum number of retries is exceeded.
    :param pool: The connection pool
    :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
    :param str url: The requested Url
    :param reason: The underlying error
    :type reason: :class:`Exception`
    """
    def __init__(
        self,
        pool: ConnectionPool | AsyncConnectionPool,
        url: str,
        reason: Exception | None = None,
    ) -> None:
        # The triggering error is kept accessible for callers that need to
        # inspect why retries were exhausted.
        self.reason = reason
        message = f"Max retries exceeded with url: {url} (Caused by {reason!r})"
        super().__init__(pool, url, message)
class HostChangedError(RequestError):
    """Raised when an existing pool gets a request for a foreign host."""
    def __init__(
        self,
        pool: ConnectionPool | AsyncConnectionPool,
        url: str,
        retries: Retry | int = 3,
    ) -> None:
        message = f"Tried to open a foreign host with url: {url}"
        super().__init__(pool, url, message)
        self.retries = retries
class TimeoutStateError(HTTPError):
    """Raised when passing an invalid state to a timeout"""
class TimeoutError(HTTPError):
    """Raised when a socket timeout error occurs.
    Catching this error will catch both :exc:`ReadTimeoutErrors
    <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
    """
class ReadTimeoutError(TimeoutError, RequestError):
    """Raised when a socket timeout occurs while receiving data from a server"""
# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
    """Raised when a socket timeout occurs while connecting to a server"""
class NewConnectionError(ConnectTimeoutError, HTTPError):
    """Raised when we fail to establish a new connection. Usually ECONNREFUSED."""
    def __init__(
        self,
        conn: HTTPConnection
        | AsyncHTTPConnection
        | ConnectionPool
        | AsyncConnectionPool,
        message: str,
    ) -> None:
        self.conn = conn
        super().__init__(f"{conn}: {message}")
    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. The connection object is dropped: it is not picklable.
        return self.__class__, (None, None)
class NameResolutionError(NewConnectionError):
    """Raised when host name resolution fails."""
    def __init__(
        self,
        host: str,
        conn: HTTPConnection
        | AsyncHTTPConnection
        | ConnectionPool
        | AsyncConnectionPool,
        reason: socket.gaierror,
    ):
        message = f"Failed to resolve '{host}' ({reason})"
        super().__init__(conn, message)
    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. Host, connection and reason are all dropped.
        return self.__class__, (None, None, None)
class EmptyPoolError(PoolError):
    """Raised when a pool runs out of connections and no more are allowed."""
class FullPoolError(PoolError):
    """Raised when we try to add a connection to a full pool in blocking mode."""
class ClosedPoolError(PoolError):
    """Raised when a request enters a pool after the pool has been closed."""
class LocationValueError(ValueError, HTTPError):
    """Raised when there is something wrong with a given URL input."""
class LocationParseError(LocationValueError):
    """Raised when get_host or similar fails to parse the URL input."""
    def __init__(self, location: str) -> None:
        message = f"Failed to parse: {location}"
        super().__init__(message)
        # Keep the offending input for programmatic inspection.
        self.location = location
class URLSchemeUnknown(LocationValueError):
    """Raised when a URL input has an unsupported scheme."""
    def __init__(self, scheme: str):
        message = f"Not supported URL scheme {scheme}"
        super().__init__(message)
        self.scheme = scheme
class ResponseError(HTTPError):
    """Used as a container for an error reason supplied in a MaxRetryError."""
    # Message templates used when building retry-exhaustion reasons.
    GENERIC_ERROR = "too many error responses"
    SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning):
    """Warned when performing security reducing actions"""
class InsecureRequestWarning(SecurityWarning):
    """Warned when making an unverified HTTPS request."""
class NotOpenSSLWarning(SecurityWarning):
    """Warned when using unsupported SSL library"""
class SystemTimeWarning(SecurityWarning):
    """Warned when system time is suspected to be wrong"""
class InsecurePlatformWarning(SecurityWarning):
    """Warned when certain TLS/SSL configuration is not available on a platform."""
class DependencyWarning(HTTPWarning):
    """
    Warned when an attempt is made to import a module with missing optional
    dependencies.
    """
class ResponseNotChunked(ProtocolError, ValueError):
    """Response needs to be chunked in order to read it as chunks."""
class BodyNotHttplibCompatible(HTTPError):
    """
    Body should be :class:`http.client.HTTPResponse` like
    (have an fp attribute which returns raw chunks) for read_chunked().
    """
class IncompleteRead(ProtocolError):
    """
    Response length doesn't match expected Content-Length
    Modeled after :class:`http.client.IncompleteRead`, but stores ``partial``
    as an int byte count to avoid creating large objects on streamed reads.
    """
    def __init__(self, partial: int, expected: int | None = None) -> None:
        self.partial = partial
        self.expected = expected
    def __repr__(self) -> str:
        if self.expected is not None:
            return f"IncompleteRead({self.partial} bytes read, {self.expected} more expected)"
        return f"IncompleteRead({self.partial} bytes read)"
    # Use object's default str() (which falls back to repr-style output)
    # instead of Exception.__str__, which would render the empty args tuple.
    __str__ = object.__str__
class InvalidChunkLength(ProtocolError):
    """Invalid chunk length in a chunked response."""
    def __init__(
        self, response: HTTPResponse | AsyncHTTPResponse, length: bytes
    ) -> None:
        # Snapshot where the response stood when the bad chunk header appeared.
        self.partial: int = response.tell()
        self.expected: int | None = response.length_remaining
        self.response = response
        self.length = length
    def __repr__(self) -> str:
        return "InvalidChunkLength(got length %r, %i bytes read)" % (
            self.length,
            self.partial,
        )
class InvalidHeader(HTTPError):
    """The header provided was somehow invalid."""
class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
    """ProxyManager does not support the supplied scheme"""
    # TODO(t-8ch): Stop inheriting from AssertionError in v2.0.
    def __init__(self, scheme: str | None) -> None:
        # 'localhost' is here because our URL parser parses
        # localhost:8080 -> scheme=localhost, remove if we fix this.
        if scheme == "localhost":
            scheme = None
        if scheme is None:
            message = "Proxy URL had no scheme, should start with http:// or https://"
        else:
            message = f"Proxy URL had unsupported scheme {scheme}, should use http:// or https://"
        super().__init__(message)
class ProxySchemeUnsupported(ValueError):
    """Fetching HTTPS resources through HTTPS proxies is unsupported"""
class HeaderParsingError(HTTPError):
    """Raised by assert_header_parsing, but we convert it to a log.warning statement."""
    def __init__(
        self, defects: list[MessageDefect], unparsed_data: bytes | str | None
    ) -> None:
        message = f"{defects or 'Unknown'}, unparsed data: {unparsed_data!r}"
        super().__init__(message)
class UnrewindableBodyError(HTTPError):
    """urllib3 encountered an error when trying to rewind a body"""
class EarlyResponse(HTTPError):
    """urllib3 received a response prior to sending the whole body"""
    def __init__(self, promise: ResponsePromise) -> None:
        self.promise = promise
class ResponseNotReady(HTTPError):
    """Kept for BC"""
class RecoverableError(HTTPError):
    """This error is never leaked in the upper stack, it serves only an internal purpose."""
class MustDowngradeError(RecoverableError):
    """An error occurred with a protocol and can be circumvented using an older protocol."""
class MustRedialError(RecoverableError):
    """Unused legacy exception. Remove it in a next major."""

View File

@@ -0,0 +1,283 @@
from __future__ import annotations
import email.utils
import mimetypes
import typing
if typing.TYPE_CHECKING:
from ._typing import _TYPE_FIELD_VALUE, _TYPE_FIELD_VALUE_TUPLE
def guess_content_type(
    filename: str | None, default: str = "application/octet-stream"
) -> str:
    """
    Guess the "Content-Type" of a file.
    :param filename:
        The filename to guess the "Content-Type" of using :mod:`mimetypes`.
    :param default:
        If no "Content-Type" can be guessed, default to `default`.
    """
    # Missing or empty filename: nothing to inspect, fall back immediately.
    if not filename:
        return default
    guessed, _encoding = mimetypes.guess_type(filename)
    return guessed or default
def format_header_param_rfc2231(name: str, value: _TYPE_FIELD_VALUE) -> str:
    """
    Helper function to format and quote a single header parameter using the
    strategy defined in RFC 2231.
    Particularly useful for header parameters which might contain
    non-ASCII values, like file names. This follows
    `RFC 2388 Section 4.4 <https://tools.ietf.org/html/rfc2388#section-4.4>`_.
    :param name:
        The name of the parameter, a string expected to be ASCII only.
    :param value:
        The value of the parameter, provided as ``bytes`` or `str``.
    :returns:
        An RFC-2231-formatted unicode string.
    .. deprecated:: 2.0.0
        Will be removed in urllib3 v2.1.0. This is not valid for
        ``multipart/form-data`` header parameters.
    """
    if isinstance(value, bytes):
        value = value.decode("utf-8")

    # Fast path: plain quoting is fine when the value has no quote/backslash/
    # newline characters AND is pure ASCII.
    if all(ch not in value for ch in '"\\\r\n'):
        quoted = f'{name}="{value}"'
        try:
            quoted.encode("ascii")
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass
        else:
            return quoted

    # Slow path: percent-encode the value per RFC 2231 with an explicit charset.
    encoded = email.utils.encode_rfc2231(value, "utf-8")
    return f"{name}*={encoded}"
def format_multipart_header_param(name: str, value: _TYPE_FIELD_VALUE) -> str:
    """
    Format and quote a single multipart header parameter.
    This follows the `WHATWG HTML Standard`_ as of 2021/06/10, matching
    the behavior of current browser and curl versions. Values are
    assumed to be UTF-8. The ``\\n``, ``\\r``, and ``"`` characters are
    percent encoded.
    .. _WHATWG HTML Standard:
        https://html.spec.whatwg.org/multipage/
        form-control-infrastructure.html#multipart-form-data
    :param name:
        The name of the parameter, an ASCII-only ``str``.
    :param value:
        The value of the parameter, a ``str`` or UTF-8 encoded
        ``bytes``.
    :returns:
        A string ``name="value"`` with the escaped value.
    .. versionchanged:: 2.0.0
        Matches the WHATWG HTML Standard as of 2021/06/10. Control
        characters are no longer percent encoded.
    .. versionchanged:: 2.0.0
        Renamed from ``format_header_param_html5`` and
        ``format_header_param``. The old names will be removed in
        urllib3 v2.1.0.
    """
    if isinstance(value, bytes):
        value = value.decode("utf-8")

    # Percent-encode only LF (10), CR (13) and double-quote (34), per the
    # WHATWG multipart/form-data serialization rules.
    escaped = value.translate({10: "%0A", 13: "%0D", 34: "%22"})
    return f'{name}="{escaped}"'
class RequestField:
    """
    A data container for one request body parameter.

    :param name:
        The name of this request field. Must be unicode.
    :param data:
        The data/value body.
    :param filename:
        An optional filename of the request field. Must be unicode.
    :param headers:
        An optional dict-like object of headers to initially use for the field.

    .. versionchanged:: 2.0.0
        The ``header_formatter`` parameter is deprecated and will
        be removed in urllib3 v2.1.0.
    """

    def __init__(
        self,
        name: str,
        data: _TYPE_FIELD_VALUE,
        filename: str | None = None,
        headers: typing.Mapping[str, str] | None = None,
        header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
    ):
        self._name = name
        self._filename = filename
        self.data = data
        # Copy the caller's mapping so later mutation of our headers
        # cannot leak back into the caller's dict.
        self.headers: dict[str, str | None] = dict(headers) if headers else {}
        # Deprecated escape hatch: fall back to the WHATWG formatter.
        self.header_formatter = (
            header_formatter
            if header_formatter is not None
            else format_multipart_header_param
        )

    @classmethod
    def from_tuples(
        cls,
        fieldname: str,
        value: _TYPE_FIELD_VALUE_TUPLE,
        header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
    ) -> RequestField:
        """
        A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.

        Supports constructing :class:`~urllib3.fields.RequestField` from
        key/value strings AND key/filetuple pairs. A filetuple is a
        (filename, data, MIME type) tuple where the MIME type is optional.
        For example::

            'foo': 'bar',
            'fakefile': ('foofile.txt', 'contents of foofile'),
            'realfile': ('barfile.txt', open('realfile').read()),
            'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
            'nonamefile': 'contents of nonamefile field',

        Field names and filenames must be unicode.
        """
        filename: str | None = None
        content_type: str | None = None
        data: _TYPE_FIELD_VALUE

        if isinstance(value, tuple):
            if len(value) == 3:
                filename, data, content_type = value
            else:
                filename, data = value
                # MIME type omitted: infer it from the filename extension.
                content_type = guess_content_type(filename)
        else:
            data = value

        field = cls(
            fieldname, data, filename=filename, header_formatter=header_formatter
        )
        field.make_multipart(content_type=content_type)
        return field

    def _render_part(self, name: str, value: _TYPE_FIELD_VALUE) -> str:
        """
        Override this method to change how each multipart header
        parameter is formatted. By default, this calls
        :func:`format_multipart_header_param`.

        :param name:
            The name of the parameter, an ASCII-only ``str``.
        :param value:
            The value of the parameter, a ``str`` or UTF-8 encoded ``bytes``.

        :meta public:
        """
        return self.header_formatter(name, value)

    def _render_parts(
        self,
        header_parts: (
            dict[str, _TYPE_FIELD_VALUE | None]
            | typing.Sequence[tuple[str, _TYPE_FIELD_VALUE | None]]
        ),
    ) -> str:
        """
        Format and quote a composite header value.

        Useful for single headers that are composed of multiple items,
        e.g. 'Content-Disposition' fields.

        :param header_parts:
            A sequence of (k, v) tuples or a :class:`dict` of (k, v) to
            format as `k1="v1"; k2="v2"; ...`. ``None`` values are skipped.
        """
        pairs: typing.Iterable[tuple[str, _TYPE_FIELD_VALUE | None]]
        pairs = header_parts.items() if isinstance(header_parts, dict) else header_parts
        return "; ".join(
            self._render_part(key, val) for key, val in pairs if val is not None
        )

    def render_headers(self) -> str:
        """
        Render the headers for this request field as a CRLF-joined string,
        terminated by a blank line.
        """
        rendered: list[str] = []
        # These three headers come first, in this fixed order.
        priority = ["Content-Disposition", "Content-Type", "Content-Location"]
        for key in priority:
            if self.headers.get(key, False):
                rendered.append(f"{key}: {self.headers[key]}")

        # Any remaining (non-empty) headers follow in insertion order.
        for key, val in self.headers.items():
            if key not in priority and val:
                rendered.append(f"{key}: {val}")

        # Trailing element yields the blank line that ends the header block.
        rendered.append("\r\n")
        return "\r\n".join(rendered)

    def make_multipart(
        self,
        content_disposition: str | None = None,
        content_type: str | None = None,
        content_location: str | None = None,
    ) -> None:
        """
        Makes this request field into a multipart request field.

        This method overrides the "Content-Disposition", "Content-Type" and
        "Content-Location" headers of the request parameter.

        :param content_disposition:
            The 'Content-Disposition' of the request body. Defaults to 'form-data'.
        :param content_type:
            The 'Content-Type' of the request body.
        :param content_location:
            The 'Content-Location' of the request body.
        """
        disposition = (content_disposition or "form-data") + "; " + self._render_parts(
            (("name", self._name), ("filename", self._filename))
        )
        self.headers["Content-Disposition"] = disposition
        self.headers["Content-Type"] = content_type
        self.headers["Content-Location"] = content_location

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import binascii
import codecs
import os
import typing
from io import BytesIO
from ._typing import _TYPE_FIELD_VALUE_TUPLE, _TYPE_FIELDS
from .fields import RequestField
# codecs.lookup("utf-8") returns a CodecInfo tuple; index 3 is the
# StreamWriter class. Wrapping a binary buffer with it lets str payloads
# be written out as UTF-8 bytes.
writer = codecs.lookup("utf-8")[3]
def choose_boundary() -> str:
    """
    Our embarrassingly-simple replacement for mimetools.choose_boundary.

    Returns 16 random bytes rendered as a 32-character lowercase hex string.
    """
    return os.urandom(16).hex()
def iter_field_objects(fields: _TYPE_FIELDS) -> typing.Iterable[RequestField]:
    """
    Iterate over fields, normalizing each entry to a
    :class:`~urllib3.fields.RequestField`.

    Supports lists of (k, v) tuples, dicts, and lists of
    :class:`~urllib3.fields.RequestField`.
    """
    source: typing.Iterable[RequestField | tuple[str, _TYPE_FIELD_VALUE_TUPLE]]
    source = fields.items() if isinstance(fields, typing.Mapping) else fields

    for entry in source:
        if isinstance(entry, RequestField):
            yield entry
        else:
            # Old-style tuple: delegate conversion to the factory.
            yield RequestField.from_tuples(*entry)
def encode_multipart_formdata(
    fields: _TYPE_FIELDS, boundary: str | None = None
) -> tuple[bytes, str]:
    """
    Encode a dictionary of ``fields`` using the multipart/form-data MIME format.

    :param fields:
        Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
        Values are processed by :func:`urllib3.fields.RequestField.from_tuples`.
    :param boundary:
        If not specified, then a random boundary will be generated using
        :func:`urllib3.filepost.choose_boundary`.
    :returns:
        A ``(body, content_type)`` tuple: the encoded bytes and the matching
        ``multipart/form-data`` Content-Type header value.
    """
    if boundary is None:
        boundary = choose_boundary()

    body = BytesIO()
    delimiter = f"--{boundary}\r\n".encode("latin-1")

    for field in iter_field_objects(fields):
        body.write(delimiter)
        # Headers are str; route them through the UTF-8 StreamWriter.
        writer(body).write(field.render_headers())

        data = field.data
        if isinstance(data, int):
            data = str(data)  # Backwards compatibility
        if isinstance(data, str):
            writer(body).write(data)
        else:
            body.write(data)

        body.write(b"\r\n")

    # Closing delimiter carries the trailing "--".
    body.write(f"--{boundary}--\r\n".encode("latin-1"))

    return body.getvalue(), f"multipart/form-data; boundary={boundary}"

View File

@@ -0,0 +1,23 @@
# Dummy file to match upstream modules
# without actually serving them.
# urllib3-future diverged from urllib3.
# only the top-level (public API) are guaranteed to be compatible.
# in-fact urllib3-future propose a better way to migrate/transition toward
# newer protocols.
from __future__ import annotations
import warnings
def inject_into_urllib3() -> None:
    """Compatibility shim: warn that the http2 patching module is unnecessary here."""
    message = (
        "urllib3-future do not propose the http2 module as it is useless to us. "
        "enjoy HTTP/1.1, HTTP/2, and HTTP/3 without hacks. urllib3-future just works out "
        "of the box with all protocols. No hassles."
    )
    warnings.warn(message, UserWarning)
def extract_from_urllib3() -> None:
    """Compatibility shim: nothing was injected, so there is nothing to undo."""
    return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,2 @@
# Instruct type checkers to look for inline type annotations in this package.
# See PEP 561.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,41 @@
# For backwards compatibility, provide imports that used to be here.
from __future__ import annotations
from .connection import is_connection_dropped
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
from .response import is_fp_closed, parse_alt_svc
from .retry import Retry
from .ssl_ import (
ALPN_PROTOCOLS,
SSLContext,
assert_fingerprint,
create_urllib3_context,
resolve_cert_reqs,
resolve_ssl_version,
ssl_wrap_socket,
)
from .timeout import Timeout
from .url import Url, parse_url
from .wait import wait_for_read, wait_for_write
# Explicit public API of this module: names re-exported above purely for
# backwards compatibility with code importing from `urllib3.util`.
__all__ = (
    "SSLContext",
    "ALPN_PROTOCOLS",
    "Retry",
    "Timeout",
    "Url",
    "assert_fingerprint",
    "create_urllib3_context",
    "is_connection_dropped",
    "is_fp_closed",
    "parse_url",
    "make_headers",
    "resolve_cert_reqs",
    "resolve_ssl_version",
    "ssl_wrap_socket",
    "wait_for_read",
    "wait_for_write",
    "SKIP_HEADER",
    "SKIPPABLE_HEADERS",
    "parse_alt_svc",
)

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
import io
import os
import typing
import warnings
from pathlib import Path
if typing.TYPE_CHECKING:
import ssl
from ...contrib.imcc import load_cert_chain as _ctx_load_cert_chain
from ...contrib.ssa import AsyncSocket, SSLAsyncSocket
from ...exceptions import SSLError
from ..ssl_ import (
ALPN_PROTOCOLS,
_CacheableSSLContext,
_is_key_file_encrypted,
create_urllib3_context,
_KnownCaller,
)
class DummyLock:
    """A no-op context manager that stands in for a real lock."""

    def __enter__(self) -> DummyLock:
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
        # Nothing to release; exceptions propagate (implicit None return).
        return None
class _NoLock_CacheableSSLContext(_CacheableSSLContext):
    """Deprecated: we no longer avoid the lock in the async part because we want to allow many loop within many thread..."""

    def __init__(self, maxsize: int | None = 32):
        super().__init__(maxsize=maxsize)
        # Historical behavior: replace the parent's real lock with a no-op
        # stand-in so "locking" costs nothing in the async code path.
        self._lock = DummyLock()  # type: ignore[assignment]
# Module-level SSLContext cache shared by ssl_wrap_socket below; entries are
# keyed on the TLS configuration parameters passed to its lock() call.
_SSLContextCache = _CacheableSSLContext()
async def ssl_wrap_socket(
    sock: AsyncSocket,
    keyfile: str | None = None,
    certfile: str | None = None,
    cert_reqs: int | None = None,
    ca_certs: str | None = None,
    server_hostname: str | None = None,
    ssl_version: int | None = None,
    ciphers: str | None = None,
    ssl_context: ssl.SSLContext | None = None,
    ca_cert_dir: str | None = None,
    key_password: str | None = None,
    ca_cert_data: None | str | bytes = None,
    tls_in_tls: bool = False,
    alpn_protocols: list[str] | None = None,
    certdata: str | bytes | None = None,
    keydata: str | bytes | None = None,
    check_hostname: bool | None = None,
    ssl_minimum_version: int | None = None,
    ssl_maximum_version: int | None = None,
) -> SSLAsyncSocket:
    """
    All arguments except for server_hostname, ssl_context, and ca_cert_dir have
    the same meaning as they do when using :func:`ssl.wrap_socket`.

    :param server_hostname:
        When SNI is supported, the expected hostname of the certificate
    :param ssl_context:
        A pre-made :class:`SSLContext` object. If none is provided, one will
        be created using :func:`create_urllib3_context`.
    :param ciphers:
        A string of ciphers we wish the client to support.
    :param ca_cert_dir:
        A directory containing CA certificates in multiple separate files, as
        supported by OpenSSL's -CApath flag or the capath argument to
        SSLContext.load_verify_locations().
    :param key_password:
        Optional password if the keyfile is encrypted.
    :param ca_cert_data:
        Optional string containing CA certificates in PEM format suitable for
        passing as the cadata parameter to SSLContext.load_verify_locations()
    :param tls_in_tls:
        No-op in asynchronous mode. Call wrap_socket of the SSLAsyncSocket later.
    :param alpn_protocols:
        Manually specify other protocols to be announced during tls handshake.
    :param certdata:
        Specify an in-memory client intermediary certificate for mTLS.
    :param keydata:
        Specify an in-memory client intermediary key for mTLS.
    """
    context = ssl_context
    # A caller-supplied context is used as-is and never cached: we cannot
    # know what else the caller configured on it, so sharing it is unsafe.
    cache_disabled: bool = context is not None
    # The cache key is every parameter that influences how the context is
    # configured below. Path-like params are normalized via Path; the
    # SSLKEYLOGFILE env var is included so toggling it changes the key.
    with _SSLContextCache.lock(
        keyfile,
        certfile if certfile is None else Path(certfile),
        cert_reqs,
        ca_certs,
        ssl_version,
        ciphers,
        ca_cert_dir if ca_cert_dir is None else Path(ca_cert_dir),
        alpn_protocols,
        certdata,
        keydata,
        key_password,
        ca_cert_data,
        os.getenv("SSLKEYLOGFILE", None),
        ssl_minimum_version,
        ssl_maximum_version,
        check_hostname,
    ):
        cached_ctx = _SSLContextCache.get() if not cache_disabled else None

        if cached_ctx is None:
            # Cache miss (or caching disabled): build and configure a context.
            if context is None:
                context = create_urllib3_context(
                    ssl_version,
                    cert_reqs,
                    ciphers=ciphers,
                    caller_id=_KnownCaller.NIQUESTS,
                    ssl_minimum_version=ssl_minimum_version,
                    ssl_maximum_version=ssl_maximum_version,
                )

            if cert_reqs is not None:
                context.verify_mode = cert_reqs  # type: ignore[assignment]

            if check_hostname is not None:
                context.check_hostname = check_hostname

            if ca_certs or ca_cert_dir or ca_cert_data:
                # SSLContext does not support bytes for cadata[...]
                if ca_cert_data and isinstance(ca_cert_data, bytes):
                    ca_cert_data = ca_cert_data.decode()
                try:
                    context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
                except OSError as e:
                    # Normalize filesystem errors into the library's SSLError.
                    raise SSLError(e) from e

            elif hasattr(context, "load_default_certs"):
                store_stats = context.cert_store_stats()

                # try to load OS default certs; works well on Windows.
                # Only done when no CA was explicitly loaded above.
                if "x509_ca" not in store_stats or not store_stats["x509_ca"]:
                    context.load_default_certs()

            # Attempt to detect if we get the goofy behavior of the
            # keyfile being encrypted and OpenSSL asking for the
            # passphrase via the terminal and instead error out.
            if keyfile and key_password is None and _is_key_file_encrypted(keyfile):
                raise SSLError("Client private key is encrypted, password is required")

            if certfile:
                # On-disk client certificate takes precedence over certdata.
                if key_password is None:
                    context.load_cert_chain(certfile, keyfile)
                else:
                    context.load_cert_chain(certfile, keyfile, key_password)
            elif certdata and keydata:
                # In-memory client cert/key pair (mTLS); not supported on
                # every platform, hence the soft failure below.
                try:
                    _ctx_load_cert_chain(context, certdata, keydata, key_password)
                except io.UnsupportedOperation as e:
                    warnings.warn(
                        f"""Passing in-memory client/intermediary certificate for mTLS is unsupported on your platform.
Reason: {e}. It will be picked out if you upgrade to a QUIC connection.""",
                        UserWarning,
                    )

            try:
                context.set_alpn_protocols(alpn_protocols or ALPN_PROTOCOLS)
            except (
                NotImplementedError
            ):  # Defensive: in CI, we always have set_alpn_protocols
                pass

            if ciphers:
                context.set_ciphers(ciphers)

            if not cache_disabled:
                # Store the fully configured context for reuse under this key.
                _SSLContextCache.save(context)
        else:
            # Cache hit: reuse the previously configured context untouched.
            context = cached_ctx

    # Hand the configured context to the async socket; the TLS handshake is
    # performed by its wrap_socket coroutine (tls_in_tls is a no-op here,
    # see the docstring).
    return await sock.wrap_socket(context, server_hostname=server_hostname)

Some files were not shown because too many files have changed in this diff Show More