fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097)과의 포트 충돌로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,237 @@
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install PySocks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...``)
- SOCKS4 (``proxy_url='socks4://...``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import annotations
try:
import socks
except ImportError:
import warnings
from ..exceptions import DependencyWarning
warnings.warn(
(
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning,
)
raise
import typing
from socket import timeout as SocketTimeout
from .._typing import _TYPE_SOCKS_OPTIONS
from ..backend import HttpVersion
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
try:
import ssl
except ImportError:
ssl = None # type: ignore[assignment]
class SOCKSConnection(HTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # Stash the proxy parameters before the parent constructor runs so
        # _new_conn() can reach them when the pool opens a socket.
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    def _new_conn(self) -> socks.socksocket:
        """
        Establish a new connection via the SOCKS proxy.

        :returns: a connected ``socks.socksocket``.
        :raises ConnectTimeoutError: when the connection/handshake exceeds
            ``self.timeout``.
        :raises NewConnectionError: for any other connection failure.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # PySocks accepts TCP socket options as (level, optname, value)
            # triples only. 4-tuples carry an extra protocol tag: UDP-tagged
            # options are dropped, any other tagged option is forwarded
            # without its tag.
            only_tcp_options = []
            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])
            extra_kw["socket_options"] = only_tcp_options

        try:
            conn = socks.create_connection(
                (self.host, self.port),
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                proxy_addr=self._socks_options["proxy_host"],
                proxy_port=self._socks_options["proxy_port"],  # type: ignore[arg-type]
                proxy_username=self._socks_options["username"],
                proxy_password=self._socks_options["password"],
                proxy_rdns=self._socks_options["rdns"],
                timeout=self.timeout,  # type: ignore[arg-type]
                **extra_kw,
            )
        except SocketTimeout as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e
        except socks.ProxyError as e:
            # This is fragile as hell, but it seems to be the only way to raise
            # useful errors here.
            if e.socket_err:
                error = e.socket_err
                if isinstance(error, SocketTimeout):
                    raise ConnectTimeoutError(
                        self,
                        f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
                    ) from e
                else:
                    # Adding `from e` messes with coverage somehow, so it's omitted.
                    # See #2386.
                    raise NewConnectionError(
                        self, f"Failed to establish a new connection: {error}"
                    )
            else:
                raise NewConnectionError(
                    self, f"Failed to establish a new connection: {e}"
                ) from e
        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        return conn
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
    """A TLS-wrapped HTTP connection tunnelled through a SOCKS proxy."""

    pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):
    """Connection pool that hands out SOCKS-aware plain-text connections."""

    # Swap in the SOCKS-capable connection class; everything else is inherited.
    ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
    """Connection pool that hands out SOCKS-aware TLS connections."""

    # Swap in the SOCKS-capable connection class; everything else is inherited.
    ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # Credentials embedded in the proxy URL are honoured only when the
        # caller did not pass any explicitly.
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (PySocks proxy type, resolve DNS remotely?)
        scheme_dispatch = {
            "socks5": (socks.PROXY_TYPE_SOCKS5, False),
            "socks5h": (socks.PROXY_TYPE_SOCKS5, True),
            "socks4": (socks.PROXY_TYPE_SOCKS4, False),
            "socks4a": (socks.PROXY_TYPE_SOCKS4, True),
        }

        if parsed.scheme not in scheme_dispatch:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = scheme_dispatch[parsed.scheme]

        self.proxy_url = proxy_url

        connection_pool_kw["_socks_options"] = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }

        # SOCKS tunnels carry TCP only, so HTTP/3 (QUIC over UDP) is disabled
        # for every pool this manager creates.
        connection_pool_kw.setdefault("disabled_svn", set())
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme

View File

@@ -0,0 +1,173 @@
"""
This is hazmat. It can blow up anytime.
Use it with precautions!
Reasoning behind this:
1) python-socks requires another dependency, namely asyncio-timeout, that is one too much for us.
2) it does not support our AsyncSocket wrapper (it has its own internally)
"""
from __future__ import annotations
import asyncio
import socket
import typing
import warnings
from python_socks import _abc as abc
# look the other way if unpleasant. No choice for now.
# will start discussions once we have a solid traffic.
from python_socks._connectors.abc import AsyncConnector
from python_socks._connectors.socks4_async import Socks4AsyncConnector
from python_socks._connectors.socks5_async import Socks5AsyncConnector
from python_socks._errors import ProxyError, ProxyTimeoutError
from python_socks._helpers import parse_proxy_url
from python_socks._protocols.errors import ReplyError
from python_socks._types import ProxyType
from .ssa import AsyncSocket
from .ssa._timeout import timeout as timeout_
class Resolver(abc.AsyncResolver):
    """Resolve hostnames through the event loop's ``getaddrinfo``."""

    def __init__(self, loop: asyncio.AbstractEventLoop):
        self._loop = loop

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_UNSPEC
    ) -> tuple[socket.AddressFamily, str]:
        """Return the (family, address) of the first record, preferring the
        lowest address family."""
        records = await self._loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=socket.SOCK_STREAM,
        )

        if not records:  # Defensive:
            raise OSError(f"Can`t resolve address {host}:{port} [{family}]")

        # First record with the smallest family value — equivalent to a
        # stable sort by family followed by taking the head.
        chosen_family, _, _, _, sockaddr = min(records, key=lambda record: record[0])
        return chosen_family, sockaddr[0]
def create_connector(
    proxy_type: ProxyType,
    username: str | None,
    password: str | None,
    rdns: bool,
    resolver: abc.AsyncResolver,
) -> AsyncConnector:
    """Build the async SOCKS connector matching *proxy_type*.

    :raises ValueError: for any proxy type other than SOCKS4/SOCKS5.
    """
    if proxy_type == ProxyType.SOCKS5:
        return Socks5AsyncConnector(
            username=username,
            password=password,
            rdns=rdns,
            resolver=resolver,
        )

    if proxy_type == ProxyType.SOCKS4:
        # SOCKS4 has no password authentication; the username is sent as
        # the request's user id field.
        return Socks4AsyncConnector(
            user_id=username,
            rdns=rdns,
            resolver=resolver,
        )

    raise ValueError(f"Invalid proxy type: {proxy_type}")
class AsyncioProxy:
    """Asyncio variant of python-socks' Proxy, adapted to our AsyncSocket wrapper."""

    def __init__(
        self,
        proxy_type: ProxyType,
        host: str,
        port: int,
        username: str | None = None,
        password: str | None = None,
        rdns: bool = False,
    ):
        self._loop = asyncio.get_event_loop()

        self._proxy_type = proxy_type
        self._proxy_host = host
        self._proxy_port = port
        self._password = password
        self._username = username
        self._rdns = rdns

        # DNS lookups go through the loop's getaddrinfo (see Resolver above).
        self._resolver = Resolver(loop=self._loop)

    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        timeout: float | None = None,
        _socket: AsyncSocket | None = None,
    ) -> AsyncSocket:
        """
        Run the SOCKS handshake over *_socket* toward *dest_host*:*dest_port*.

        :raises ProxyTimeoutError: when the handshake exceeds *timeout*
            (defaults to 60 seconds).
        """
        if timeout is None:
            timeout = 60

        try:
            async with timeout_(timeout):
                # our dependency started to deprecate passing "_socket"
                # which is ... vital for our integration. We'll start by silencing the warning.
                # then we'll think on how to proceed.
                # A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
                # B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", DeprecationWarning)
                    return await self._connect(
                        dest_host=dest_host,
                        dest_port=dest_port,
                        _socket=_socket,  # type: ignore[arg-type]
                    )
        except asyncio.TimeoutError as e:
            raise ProxyTimeoutError(f"Proxy connection timed out: {timeout}") from e

    async def _connect(
        self, dest_host: str, dest_port: int, _socket: AsyncSocket
    ) -> AsyncSocket:
        # On any failure the socket is closed: it is mid-handshake and
        # cannot be reused.
        try:
            connector = create_connector(
                proxy_type=self._proxy_type,
                username=self._username,
                password=self._password,
                rdns=self._rdns,
                resolver=self._resolver,
            )
            await connector.connect(
                stream=_socket,  # type: ignore[arg-type]
                host=dest_host,
                port=dest_port,
            )
            return _socket
        except asyncio.CancelledError:  # Defensive:
            _socket.close()
            raise
        except ReplyError as e:
            # The proxy answered with a SOCKS-level refusal.
            _socket.close()
            raise ProxyError(e, error_code=e.error_code)  # type: ignore[no-untyped-call]
        except Exception:  # Defensive:
            _socket.close()
            raise

    @property
    def proxy_host(self) -> str:
        """Hostname or address of the SOCKS proxy itself."""
        return self._proxy_host

    @property
    def proxy_port(self) -> int:
        """TCP port of the SOCKS proxy itself."""
        return self._proxy_port

    @classmethod
    def create(cls, *args: typing.Any, **kwargs: typing.Any) -> AsyncioProxy:
        """Alternate constructor kept for API parity with python-socks."""
        return cls(*args, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs: typing.Any) -> AsyncioProxy:
        """Build a proxy from a URL such as ``socks5://user:pass@host:port``."""
        url_args = parse_proxy_url(url)
        return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,27 @@
# Dummy file to match upstream modules
# without actually serving them.
# urllib3-future diverged from urllib3.
# only the top-level (public API) are guaranteed to be compatible.
# in-fact urllib3-future propose a better way to migrate/transition toward
# newer protocols.
from __future__ import annotations
import warnings
def inject_into_urllib3() -> None:
    """Warn that urllib3-future provides no WASM/Emscripten support to inject."""
    message = (
        "urllib3-future does not support WASM / Emscripten platform. "
        "Please reinstall legacy urllib3 in the meantime. "
        "Run `pip uninstall -y urllib3 urllib3-future` then "
        "`pip install urllib3-future`, finally `pip install urllib3`. "
        "Sorry for the inconvenience."
    )
    warnings.warn(message, DeprecationWarning)
def extract_from_urllib3() -> None:
    """No-op counterpart of ``inject_into_urllib3`` — there is nothing to undo."""

View File

@@ -0,0 +1,39 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._configuration import QuicTLSConfig
from .protocols import (
HTTP1Protocol,
HTTP2Protocol,
HTTP3Protocol,
HTTPOverQUICProtocol,
HTTPOverTCPProtocol,
HTTPProtocol,
HTTPProtocolFactory,
)
__all__ = (
"QuicTLSConfig",
"HTTP1Protocol",
"HTTP2Protocol",
"HTTP3Protocol",
"HTTPOverQUICProtocol",
"HTTPOverTCPProtocol",
"HTTPProtocol",
"HTTPProtocolFactory",
)

View File

@@ -0,0 +1,59 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
from typing import Any, Mapping
@dataclasses.dataclass
class QuicTLSConfig:
    """
    Client TLS configuration.

    Plain data holder for the TLS knobs a QUIC configuration consumes.
    """

    #: Allows to proceed for server without valid TLS certificates.
    insecure: bool = False

    #: File with CA certificates to trust for server verification
    cafile: str | None = None

    #: Directory with CA certificates to trust for server verification
    capath: str | None = None

    #: Blob with CA certificates to trust for server verification
    cadata: bytes | None = None

    #: If provided, will trigger an additional load_cert_chain() upon the QUIC Configuration
    certfile: str | bytes | None = None
    #: Private key matching ``certfile`` (path or blob).
    keyfile: str | bytes | None = None
    #: Password protecting ``keyfile``, if any.
    keypassword: str | bytes | None = None

    #: The QUIC session ticket which should be used for session resumption
    session_ticket: Any | None = None

    #: Pinned server certificate fingerprint — format assumed hex digest; confirm with consumer.
    cert_fingerprint: str | None = None
    #: Match against the certificate's common name instead of its SAN entries.
    cert_use_common_name: bool = False
    #: Whether the peer hostname is checked against the certificate.
    verify_hostname: bool = True
    #: Hostname to assert instead of the connected host, when set.
    assert_hostname: str | None = None

    #: Cipher suite restriction — NOTE(review): consumed elsewhere; mapping shape assumed.
    ciphers: list[Mapping[str, Any]] | None = None

    #: Idle timeout (presumably seconds, given the 300.0 default) — confirm against QUIC config.
    idle_timeout: float = 300.0

View File

@@ -0,0 +1,151 @@
from __future__ import annotations
import typing
from collections import deque
from .events import Event
class StreamMatrix:
    """Efficient way to store events for concurrent streams."""

    __slots__ = (
        "_matrix",
        "_count",
        "_event_cursor_id",
    )

    def __init__(self) -> None:
        # One FIFO per stream id; the ``None`` key holds connection-wide
        # (non-stream) events.
        self._matrix: dict[int | None, deque[Event]] = {}
        # Total number of queued events across every bucket.
        self._count: int = 0
        # Monotonic counter stamped onto each event (``_id``) so arrival
        # order can be compared across buckets.
        self._event_cursor_id: int = 0

    def __len__(self) -> int:
        return self._count

    def __bool__(self) -> bool:
        return self._count > 0

    @property
    def streams(self) -> list[int]:
        # Sorted stream ids that currently hold events; the global ``None``
        # bucket is excluded.
        return sorted(i for i in self._matrix.keys() if i is not None)

    def append(self, event: Event) -> None:
        """Queue one event at the back of its stream bucket."""
        # Events without a ``stream_id`` attribute land in the global bucket.
        matrix_idx = getattr(event, "stream_id", None)

        event._id = self._event_cursor_id
        self._event_cursor_id += 1

        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()

        self._matrix[matrix_idx].append(event)
        self._count += 1

    def extend(self, events: typing.Iterable[Event]) -> None:
        """Queue many events; buckets them first so each deque is extended once."""
        triaged_events: dict[int | None, list[Event]] = {}

        for event in events:
            matrix_idx = getattr(event, "stream_id", None)

            event._id = self._event_cursor_id
            self._event_cursor_id += 1
            self._count += 1

            if matrix_idx not in triaged_events:
                triaged_events[matrix_idx] = []

            triaged_events[matrix_idx].append(event)

        for k, v in triaged_events.items():
            if k not in self._matrix:
                self._matrix[k] = deque()
            self._matrix[k].extend(v)

    def appendleft(self, event: Event) -> None:
        """Queue one event at the FRONT of its bucket (e.g. re-queue after a
        partial consume).

        NOTE(review): the event still receives a fresh, *higher* ``_id``,
        which affects cross-bucket ordering in popleft — confirm intended.
        """
        matrix_idx = getattr(event, "stream_id", None)

        event._id = self._event_cursor_id
        self._event_cursor_id += 1

        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()

        self._matrix[matrix_idx].appendleft(event)
        self._count += 1

    def popleft(self, stream_id: int | None = None) -> Event | None:
        """Pop the next event for *stream_id* (or any stream when ``None``),
        letting an older global event win over the stream's head event."""
        if self._count == 0:
            return None

        have_global_event: bool = None in self._matrix and bool(self._matrix[None])
        any_stream_event: bool = (
            bool(self._matrix) if not have_global_event else len(self._matrix) > 1
        )

        # No stream requested: pick the first stream bucket in dict
        # insertion order, skipping the global bucket.
        if stream_id is None and any_stream_event:
            matrix_dict_iter = self._matrix.__iter__()
            stream_id = next(matrix_dict_iter)
            if stream_id is None:
                stream_id = next(matrix_dict_iter)

        # A global event older (lower _id) than the candidate stream's head
        # takes precedence; likewise when the stream has no bucket at all.
        if (
            stream_id is not None
            and have_global_event
            and stream_id in self._matrix
            and self._matrix[None][0]._id < self._matrix[stream_id][0]._id
        ):
            stream_id = None
        elif have_global_event is True and stream_id not in self._matrix:
            stream_id = None

        if stream_id not in self._matrix:
            return None

        ev = self._matrix[stream_id].popleft()

        if ev is not None:
            self._count -= 1
            # Drop emptied stream buckets; the global bucket is kept even
            # when empty.
            if stream_id is not None and not self._matrix[stream_id]:
                del self._matrix[stream_id]

        return ev

    def count(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> int:
        """Number of queued events, optionally for one stream and/or
        excluding the given event types.

        NOTE(review): ``excl_event`` is ignored when ``stream_id`` is None.
        """
        if stream_id is None:
            return self._count
        if stream_id not in self._matrix:
            return 0
        return len(
            self._matrix[stream_id]
            if excl_event is None
            else [e for e in self._matrix[stream_id] if not isinstance(e, excl_event)]
        )

    def has(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Whether at least one matching event is queued.

        NOTE(review): ``excl_event`` is ignored when ``stream_id`` is None.
        """
        if stream_id is None:
            return True if self._count else False
        if stream_id not in self._matrix:
            return False
        if excl_event is not None:
            return any(
                e for e in self._matrix[stream_id] if not isinstance(e, excl_event)
            )
        return True if self._matrix[stream_id] else False

View File

@@ -0,0 +1,25 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Sequence, Tuple
# A single HTTP header as a raw (name, value) byte pair.
HeaderType = Tuple[bytes, bytes]
# An ordered collection of raw headers.
HeadersType = Sequence[HeaderType]
# A network peer as (host, port).
AddressType = Tuple[str, int]
# A UDP payload paired with its peer address.
DatagramType = Tuple[bytes, AddressType]

View File

@@ -0,0 +1,43 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
HandshakeCompleted,
HeadersReceived,
StreamEvent,
StreamReset,
StreamResetReceived,
)
__all__ = (
"Event",
"ConnectionTerminated",
"GoawayReceived",
"StreamEvent",
"StreamReset",
"StreamResetReceived",
"HeadersReceived",
"DataReceived",
"HandshakeCompleted",
"EarlyHeadersReceived",
)

View File

@@ -0,0 +1,202 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing
from dataclasses import dataclass, field
from .._typing import HeadersType
class Event:
    """
    Base class for HTTP events.

    This is an abstract base class that should not be initialized.
    """

    # Monotonic insertion id stamped by the event queue when the event is
    # stored; used to compare arrival order across streams.
    _id: int
#
# Connection events
#
@dataclass
class ConnectionTerminated(Event):
    """
    Connection was terminated.

    Extends :class:`.Event`.
    """

    #: Reason for closing the connection.
    error_code: int = 0

    #: Optional message with more information
    message: str | None = field(default=None, compare=False)

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(error_code={!r}, message={!r})".format(
            type(self).__name__, self.error_code, self.message
        )
@dataclass
class GoawayReceived(Event):
    """
    GOAWAY frame was received

    Extends :class:`.Event`.
    """

    #: Highest stream ID that could be processed.
    last_stream_id: int

    #: Reason for closing the connection.
    error_code: int = 0

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(last_stream_id={!r}, error_code={!r})".format(
            type(self).__name__, self.last_stream_id, self.error_code
        )
#
# Stream events
#
@dataclass
class StreamEvent(Event):
    """
    Event on one HTTP stream.

    This is an abstract base class that should not be used directly.

    Extends :class:`.Event`.
    """

    #: Stream ID
    stream_id: int
@dataclass
class StreamReset(StreamEvent):
    """
    One stream of an HTTP connection was reset.

    A reset stream must no longer be used; the parent connection and the
    other streams are unaffected. Abstract — emit one of the concrete
    subclasses (StreamResetSent or StreamResetReceived) instead.

    Extends :class:`.StreamEvent`.
    """

    #: Reason for closing the stream.
    error_code: int = 0
    #: A reset always terminates the stream.
    end_stream: bool = True

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(stream_id={!r}, error_code={!r})".format(
            type(self).__name__, self.stream_id, self.error_code
        )
@dataclass
class StreamResetReceived(StreamReset):
    """
    One stream of an HTTP connection was reset by the peer.

    This probably means that we did something that the peer does not like.

    Extends :class:`.StreamReset`.
    """
@dataclass
class HandshakeCompleted(Event):
    """The transport handshake finished; carries the negotiated ALPN protocol."""

    #: ALPN protocol agreed with the peer, if any.
    alpn_protocol: str | None

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(alpn={})".format(type(self).__name__, self.alpn_protocol)
@dataclass
class HeadersReceived(StreamEvent):
    """
    A frame with HTTP headers was received.

    Extends :class:`.StreamEvent`.
    """

    #: The received HTTP headers
    headers: HeadersType

    #: Signals that data will not be sent by the peer over the stream.
    end_stream: bool = False

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(stream_id={!r}, len(headers)={}, end_stream={!r})".format(
            type(self).__name__, self.stream_id, len(self.headers), self.end_stream
        )
@dataclass
class DataReceived(StreamEvent):
    """
    A frame with HTTP data was received.

    Extends :class:`.StreamEvent`.
    """

    #: The received data.
    data: bytes

    #: Signals that no more data will be sent by the peer over the stream.
    end_stream: bool = False

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(stream_id={!r}, len(data)={}, end_stream={!r})".format(
            type(self).__name__, self.stream_id, len(self.data), self.end_stream
        )
@dataclass
class EarlyHeadersReceived(StreamEvent):
    """An informational (early) headers frame; it can never close the stream."""

    #: The received HTTP headers
    headers: HeadersType

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return "{}(stream_id={!r}, len(headers)={}, end_stream=False)".format(
            type(self).__name__, self.stream_id, len(self.headers)
        )

    @property
    def end_stream(self) -> typing.Literal[False]:
        """Early headers never terminate the stream."""
        return False

View File

@@ -0,0 +1,37 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._factories import HTTPProtocolFactory
from ._protocols import (
HTTP1Protocol,
HTTP2Protocol,
HTTP3Protocol,
HTTPOverQUICProtocol,
HTTPOverTCPProtocol,
HTTPProtocol,
)
__all__ = (
"HTTP1Protocol",
"HTTP2Protocol",
"HTTP3Protocol",
"HTTPOverQUICProtocol",
"HTTPOverTCPProtocol",
"HTTPProtocol",
"HTTPProtocolFactory",
)

View File

@@ -0,0 +1,90 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
HTTP factories create HTTP protocols based on a defined set of arguments.
We define the :class:`HTTPProtocol` interface to allow interchange
HTTP versions and protocol implementations. But constructors of
the class is not part of the interface. Every implementation
can use a different options to init instances.
Factories unify access to the creation of the protocol instances,
so that clients and servers can swap protocol implementations,
delegating the initialization to factories.
"""
from __future__ import annotations
import importlib
import inspect
from abc import ABCMeta
from typing import Any
from ._protocols import HTTPOverQUICProtocol, HTTPOverTCPProtocol, HTTPProtocol
class HTTPProtocolFactory(metaclass=ABCMeta):
    """Locates and instantiates a concrete HTTP protocol implementation."""

    @staticmethod
    def new(
        type_protocol: type[HTTPProtocol],
        implementation: str | None = None,
        **kwargs: Any,
    ) -> HTTPOverQUICProtocol | HTTPOverTCPProtocol:
        """Create a new state-machine that target given protocol type.

        :param type_protocol: one of the concrete protocol interfaces
            (e.g. HTTP1Protocol) — never HTTPProtocol itself.
        :param implementation: optional private submodule suffix selecting a
            specific backend (lower-cased, prefixed with ``_``).
        :raises NotImplementedError: when the backing module cannot be
            imported or exposes no compatible class.
        """
        assert type_protocol != HTTPProtocol, (
            "HTTPProtocol is ambiguous and cannot be requested in the factory."
        )

        # Derive the HTTP major version ("1", "2", "3") from the digits of
        # the requested class' repr. The package name is stripped first so
        # digits inside it (e.g. "urllib3") cannot leak into the version.
        package_name: str = __name__.split(".")[0]
        version_target: str = "".join(
            c for c in str(type_protocol).replace(package_name, "") if c.isdigit()
        )

        module_expr: str = f".protocols.http{version_target}"

        if implementation:
            module_expr += f"._{implementation.lower()}"

        try:
            http_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.hface"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. Tried to import '{module_expr}'."
            ) from e

        # Every class in the module that implements one of the two concrete
        # protocol bases is a candidate.
        implementations: list[
            tuple[str, type[HTTPOverQUICProtocol | HTTPOverTCPProtocol]]
        ] = inspect.getmembers(
            http_module,
            lambda e: isinstance(e, type)
            and issubclass(e, (HTTPOverQUICProtocol, HTTPOverTCPProtocol)),
        )

        if not implementations:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit either from HTTPOverQUICProtocol or HTTPOverTCPProtocol."
            )

        # getmembers() yields (name, class) pairs sorted by name; pop() takes
        # the alphabetically last one.
        implementation_target: type[HTTPOverQUICProtocol | HTTPOverTCPProtocol] = (
            implementations.pop()[1]
        )

        return implementation_target(**kwargs)

View File

@@ -0,0 +1,358 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing
from abc import ABCMeta, abstractmethod
from typing import Any, Sequence
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from .._typing import HeadersType
from ..events import Event
class BaseProtocol(metaclass=ABCMeta):
    """Sans-IO common methods whenever it is TCP, UDP or QUIC."""

    @abstractmethod
    def bytes_received(self, data: bytes) -> None:
        """
        Called when some data is received.
        """
        raise NotImplementedError

    # Sending direction

    @abstractmethod
    def bytes_to_send(self) -> bytes:
        """
        Returns data for sending out of the internal data buffer.
        """
        raise NotImplementedError

    @abstractmethod
    def connection_lost(self) -> None:
        """
        Called when the connection is lost or closed.
        """
        raise NotImplementedError

    # NOTE: the two methods below are intentionally NOT abstract — concrete
    # protocols may leave them unimplemented, in which case calling them raises.
    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        """
        Verify if the client should listen network incoming data for
        the flow control update purposes.
        """
        raise NotImplementedError

    def max_frame_size(self) -> int:
        """
        Determine if the remote set a limited size for each data frame.
        """
        raise NotImplementedError
class OverTCPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top TCP.
    """

    @abstractmethod
    def eof_received(self) -> None:
        """
        Called when the other end signals it won't send any more data.
        """
        raise NotImplementedError
class OverUDPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top UDP.
    """

    # Adds nothing over BaseProtocol; serves as a marker base for
    # datagram-oriented protocols.
class OverQUICProtocol(OverUDPProtocol):
    """QUIC additions: connection IDs, session ticket and peer certificate access."""

    @property
    @abstractmethod
    def connection_ids(self) -> Sequence[bytes]:
        """
        QUIC connection IDs

        This property can be used to assign UDP packets to QUIC connections.

        :return: a sequence of connection IDs
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def session_ticket(self) -> Any | None:
        # TLS session ticket usable for session resumption, when available.
        raise NotImplementedError

    # binary_form=True yields raw bytes, otherwise a decoded dict — mirrors
    # the shape of ssl.SSLSocket.getpeercert; confirm against implementations.
    @typing.overload
    def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...

    @typing.overload
    def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...

    @abstractmethod
    def getpeercert(self, *, binary_form: bool = False) -> bytes | dict[str, Any]:
        raise NotImplementedError

    @typing.overload
    def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...

    @typing.overload
    def getissuercert(
        self, *, binary_form: Literal[False] = ...
    ) -> dict[str, Any] | None: ...

    @abstractmethod
    def getissuercert(
        self, *, binary_form: bool = False
    ) -> bytes | dict[str, Any] | None:
        raise NotImplementedError

    @abstractmethod
    def cipher(self) -> str | None:
        # Name of the negotiated cipher, if any.
        raise NotImplementedError
class HTTPProtocol(metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP connection
    """

    # Short name of the backing implementation (e.g. "h11", "h2", "qh3").
    implementation: str

    @staticmethod
    @abstractmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Return exception types that should be handled in your application."""
        raise NotImplementedError

    @property
    @abstractmethod
    def multiplexed(self) -> bool:
        """
        Whether this connection supports multiple parallel streams.
        Returns ``True`` for HTTP/2 and HTTP/3 connections.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def max_stream_count(self) -> int:
        """Determine how many concurrent streams the connection can handle."""
        raise NotImplementedError

    @abstractmethod
    def is_idle(self) -> bool:
        """
        Return True if this connection is BOTH available and not doing anything.
        """
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """
        Return whether this connection is capable of opening new streams.
        """
        raise NotImplementedError

    @abstractmethod
    def has_expired(self) -> bool:
        """
        Return whether this connection is closed or should be closed.
        """
        raise NotImplementedError

    @abstractmethod
    def get_available_stream_id(self) -> int:
        """
        Return an ID that can be used to create a new stream.
        Use the returned ID with :meth:`.submit_headers` to create the stream.
        Repeated calls may return the same value until :meth:`.submit_headers`
        is actually called.
        :return: stream ID
        """
        raise NotImplementedError

    @abstractmethod
    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP headers.
        If this is a client connection, this method starts an HTTP request.
        If this is a server connection, it starts an HTTP response.
        :param stream_id: stream ID
        :param headers: HTTP headers
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError

    @abstractmethod
    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP data.
        :param stream_id: stream ID
        :param data: payload
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError

    @abstractmethod
    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        """
        Immediately terminate a stream.
        Stream reset is used to request cancellation of a stream
        or to indicate that an error condition has occurred.
        Use :attr:`.error_codes` to obtain error codes for common problems.
        :param stream_id: stream ID
        :param error_code: indicates why the stream is being terminated
        """
        raise NotImplementedError

    @abstractmethod
    def submit_close(self, error_code: int = 0) -> None:
        """
        Submit a graceful close of the connection.
        Use :attr:`.error_codes` to obtain error codes for common problems.
        :param error_code: indicates why the connection is being closed
        """
        raise NotImplementedError

    @abstractmethod
    def next_event(self, stream_id: int | None = None) -> Event | None:
        """
        Consume next HTTP event.
        :return: an event instance
        """
        raise NotImplementedError

    def events(self, stream_id: int | None = None) -> typing.Iterator[Event]:
        """
        Consume available HTTP events.
        :return: an iterator that unpacks :meth:`.next_event` until exhausted.
        """
        while True:
            ev = self.next_event(stream_id=stream_id)
            if ev is None:
                break
            yield ev

    @abstractmethod
    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Verify if there is a queued event waiting to be consumed."""
        raise NotImplementedError

    @abstractmethod
    def reshelve(self, *events: Event) -> None:
        """Put back events into the deque."""
        raise NotImplementedError

    @abstractmethod
    def ping(self) -> None:
        """Send a PING frame to the remote peer. Thus keeping the connection alive."""
        raise NotImplementedError
class HTTPOverTCPProtocol(HTTPProtocol, OverTCPProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a TCP connection
    An interface for HTTP/1 and HTTP/2 protocols.
    Extends :class:`.HTTPProtocol`.
    """
class HTTPOverQUICProtocol(HTTPProtocol, OverQUICProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a QUIC connection
    Abstract base class for HTTP/3 protocols.
    Extends :class:`.HTTPProtocol`.
    """
class HTTP1Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/1 connection
    An interface for HTTP/1 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """

    @property
    def multiplexed(self) -> bool:
        # HTTP/1.x serves one request/response cycle at a time.
        return False

    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        # HTTP/1 has no flow-control frames; callers must treat the
        # answer as "not applicable".
        return NotImplemented  # type: ignore[no-any-return]
class HTTP2Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/2 connection
    An abstract base class for HTTP/2 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """

    @property
    def multiplexed(self) -> bool:
        # HTTP/2 supports concurrent streams over one TCP connection.
        return True
class HTTP3Protocol(HTTPOverQUICProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/3 connection
    An abstract base class for HTTP/3 implementations.
    Extends :class:`.HTTPOverQUICProtocol`
    """

    @property
    def multiplexed(self) -> bool:
        # HTTP/3 supports concurrent streams over one QUIC connection.
        return True

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._h11 import HTTP1ProtocolHyperImpl
__all__ = ("HTTP1ProtocolHyperImpl",)

View File

@@ -0,0 +1,347 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from functools import lru_cache
import h11
from h11._state import _SWITCH_UPGRADE, ConnectionState
from ..._stream_matrix import StreamMatrix
from ..._typing import HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
HeadersReceived,
)
from .._protocols import HTTP1Protocol
@lru_cache(maxsize=64)
def capitalize_header_name(name: bytes) -> bytes:
    """
    Take a header name and capitalize each dash-separated segment.

    Results are memoized (``lru_cache``) because the same few header
    names are serialized over and over on a given connection.

    Note: only ``-`` is treated as a separator; underscores are left
    untouched.

    >>> capitalize_header_name(b"x-hEllo-wORLD")
    b'X-Hello-World'
    >>> capitalize_header_name(b"server")
    b'Server'
    >>> capitalize_header_name(b"contEnt-TYPE")
    b'Content-Type'
    """
    return b"-".join(el.capitalize() for el in name.split(b"-"))
def headers_to_request(headers: HeadersType) -> h11.Event:
    """
    Build an h11 request from HTTP/2-style headers.

    Pseudo (colon) headers are folded into the request line: ``:method``,
    ``:authority`` and ``:path`` are consumed, ``:scheme`` is dropped
    (HTTP/1 has no equivalent) and anything else is rejected. A ``Host``
    header is synthesized from ``:authority`` when absent.

    :raises ValueError: on an unknown pseudo header, a missing
        ``:authority``, or a ``host`` header conflicting with it.
    """
    method: bytes | None = None
    authority: bytes | None = None
    path: bytes | None = None
    seen_host: bytes | None = None
    plain_headers: list[tuple[bytes, bytes]] = []

    for field, value in headers:
        if not field.startswith(b":"):
            if seen_host is None and field == b"host":
                seen_host = value
            # We found that many projects... actually expect the header name to be sent capitalized... hardcoded
            # within their tests. Bad news, we have to keep doing this nonsense (namely capitalize_header_name)
            plain_headers.append((capitalize_header_name(field), value))
            continue
        if field == b":method":
            method = value
        elif field == b":authority":
            authority = value
        elif field == b":path":
            path = value
        elif field != b":scheme":
            raise ValueError("Unexpected request header: " + field.decode())

    if authority is None:
        raise ValueError("Missing request header: :authority")

    # CONNECT requests are a special case: their target is the authority.
    target = authority if method == b"CONNECT" and path is None else path

    if seen_host is None:
        plain_headers.insert(0, (b"Host", authority))
    elif seen_host != authority:
        raise ValueError("Host header does not match :authority.")

    return h11.Request(
        method=method,  # type: ignore[arg-type]
        headers=plain_headers,
        target=target,  # type: ignore[arg-type]
    )
def headers_from_response(
    response: h11.InformationalResponse | h11.Response,
) -> HeadersType:
    """
    Converts an HTTP/1.0 or HTTP/1.1 response to HTTP/2-like headers.
    The status line becomes a ``:status`` pseudo header, followed by the
    raw response headers unchanged.
    """
    status_entry = (b":status", str(response.status_code).encode("ascii"))
    return [status_entry, *response.headers.raw_items()]
class RelaxConnectionState(ConnectionState):
    """
    h11 ``ConnectionState`` variant that tolerates an unsolicited
    ``101 Switching Protocols`` (upgrade) from the server, i.e. one that
    was not preceded by a client upgrade proposal — matching the lenient
    behavior of the legacy ``http.client`` standard library.
    """

    def process_event(  # type: ignore[no-untyped-def]
        self,
        role,
        event_type,
        server_switch_event=None,
    ) -> None:
        if server_switch_event is not None:
            if server_switch_event not in self.pending_switch_proposals:
                if server_switch_event is _SWITCH_UPGRADE:
                    # Register a fake pending proposal so that the strict
                    # parent class accepts the server-initiated switch.
                    warnings.warn(
                        f"Received server {server_switch_event} event without a pending proposal. "
                        "This will raise an exception in a future version. It is temporarily relaxed to match the "
                        "legacy http.client standard library.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    self.pending_switch_proposals.add(_SWITCH_UPGRADE)
        return super().process_event(role, event_type, server_switch_event)
class HTTP1ProtocolHyperImpl(HTTP1Protocol):
    """
    Sans-IO HTTP/1.1 client protocol backed by the ``h11`` library.

    Adapts h11's strictly sequential request/response model to the
    HTTP/2-like event interface used across this package: it fabricates
    monotonically increasing "stream" IDs (one per request/response
    cycle) and emits END_STREAM-style markers so upper layers can treat
    every HTTP version uniformly. It also supports protocol switching
    (e.g. WebSocket upgrade), after which bytes are passed through raw.
    """

    implementation: str = "h11"

    def __init__(self) -> None:
        # h11 client-side state machine.
        self._connection: h11.Connection = h11.Connection(h11.CLIENT)
        # Swap in the relaxed state machine so unsolicited 101 upgrades
        # are tolerated (see RelaxConnectionState).
        self._connection._cstate = RelaxConnectionState()
        # Outbound byte chunks awaiting a bytes_to_send() flush.
        self._data_buffer: list[bytes] = []
        # Inbound events, addressable per stream ID.
        self._events: StreamMatrix = StreamMatrix()
        # True once the connection is closed or has failed.
        self._terminated: bool = False
        # True once trailing data after a protocol switch was replayed.
        self._switched: bool = False
        # Pseudo stream ID for the in-flight cycle; bumped per cycle.
        self._current_stream_id: int = 1

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types callers should be prepared to catch."""
        return h11.LocalProtocolError, h11.ProtocolError, h11.RemoteProtocolError

    def is_available(self) -> bool:
        # A new request may start only once both sides are back to IDLE.
        return self._connection.our_state == self._connection.their_state == h11.IDLE

    @property
    def max_stream_count(self) -> int:
        # HTTP/1 is strictly sequential: a single stream at a time.
        return 1

    def is_idle(self) -> bool:
        # MUST_CLOSE also counts as idle: the exchange is over, nothing
        # is in flight, the connection merely awaits closure.
        return self._connection.their_state in {
            h11.IDLE,
            h11.MUST_CLOSE,
        }

    def has_expired(self) -> bool:
        return self._terminated

    def get_available_stream_id(self) -> int:
        """Return the pseudo stream ID for the next (single) cycle."""
        if not self.is_available():
            raise RuntimeError(
                "Cannot generate a new stream ID because the connection is not idle. "
                "HTTP/1.1 is not multiplexed and we do not support HTTP pipelining."
            )
        return self._current_stream_id

    def submit_close(self, error_code: int = 0) -> None:
        # HTTP/1 has no GOAWAY equivalent; closure happens at socket level.
        pass  # no-op

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """Start an HTTP/1 request from HTTP/2-style headers."""
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")
        self._h11_submit(headers_to_request(headers))
        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """Send request body data (or raw bytes after a protocol switch)."""
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # Post-upgrade, h11 no longer frames data: pass it through.
            self._data_buffer.append(data)
            if end_stream:
                self._events.append(self._connection_terminated())
            return
        self._h11_submit(h11.Data(data))
        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        # HTTP/1 cannot submit a stream (it does not have real streams).
        # But if there are no other streams, we can close the connection instead.
        self.connection_lost()

    def connection_lost(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # This method is called when the connection is closed without an EOF.
        # But not all connections support EOF, so being here does not
        # necessarily mean that something when wrong.
        #
        # The tricky part is that HTTP/1.0 server can send responses
        # without Content-Length or Transfer-Encoding headers,
        # meaning that a response body is closed with the connection.
        # In such cases, we require a proper EOF to distinguish complete
        # messages from partial messages interrupted by network failure.
        if not self._terminated:
            self._connection.send_failed()
            self._events.append(self._connection_terminated())

    def eof_received(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # h11 treats an empty feed as end-of-file.
        self._h11_data_received(b"")

    def bytes_received(self, data: bytes) -> None:
        if not data:
            return  # h11 treats empty data as EOF.
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # Post-upgrade, deliver bytes verbatim without h11 framing.
            self._events.append(DataReceived(self._current_stream_id, data))
            return
        else:
            self._h11_data_received(data)

    def bytes_to_send(self) -> bytes:
        """Drain and return pending outbound bytes."""
        data = b"".join(self._data_buffer)
        self._data_buffer.clear()
        self._maybe_start_next_cycle()
        return data

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _h11_submit(self, h11_event: h11.Event) -> None:
        # Serialize the h11 event and queue its wire bytes for sending.
        chunks = self._connection.send_with_data_passthrough(h11_event)
        if chunks:
            self._data_buffer += chunks

    def _h11_data_received(self, data: bytes) -> None:
        # Feed inbound bytes to h11, then drain whatever events resulted.
        self._connection.receive_data(data)
        self._fetch_events()

    def _fetch_events(self) -> None:
        """Pump h11 events and translate them into package-level events."""
        a = self._events.append  # hoisted bound method for the hot loop
        while not self._terminated:
            try:
                h11_event = self._connection.next_event()
            except h11.RemoteProtocolError as e:
                a(self._connection_terminated(e.error_status_hint, str(e)))
                break
            ev_type = h11_event.__class__
            if h11_event is h11.NEED_DATA or h11_event is h11.PAUSED:
                if h11.MUST_CLOSE == self._connection.their_state:
                    a(self._connection_terminated())
                else:
                    break
            elif ev_type is h11.Response:
                a(
                    HeadersReceived(
                        self._current_stream_id,
                        headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.InformationalResponse:
                a(
                    EarlyHeadersReceived(
                        stream_id=self._current_stream_id,
                        headers=headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.Data:
                # officially h11 typed data as "bytes"
                # but we... found that it store bytearray sometime.
                payload = h11_event.data  # type: ignore[union-attr]
                a(
                    DataReceived(
                        self._current_stream_id,
                        bytes(payload) if payload.__class__ is bytearray else payload,
                    )
                )
            elif ev_type is h11.EndOfMessage:
                # HTTP/2 and HTTP/3 send END_STREAM flag with HEADERS and DATA frames.
                # We emulate similar behavior for HTTP/1.
                if h11_event.headers:  # type: ignore[union-attr]
                    # Trailers present: they carry the end-of-stream marker.
                    last_event: HeadersReceived | DataReceived = HeadersReceived(
                        self._current_stream_id,
                        h11_event.headers,  # type: ignore[union-attr]
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                else:
                    last_event = DataReceived(
                        self._current_stream_id,
                        b"",
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                a(last_event)
                self._maybe_start_next_cycle()
            elif ev_type is h11.ConnectionClosed:
                a(self._connection_terminated())

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> Event:
        # Marks the protocol dead and builds (but does not queue) the event.
        self._terminated = True
        return ConnectionTerminated(error_code, message)

    def _maybe_start_next_cycle(self) -> None:
        # Once both sides are DONE, reuse the connection for the next
        # request/response cycle under a fresh pseudo stream ID.
        if h11.DONE == self._connection.our_state == self._connection.their_state:
            self._connection.start_next_cycle()
            self._current_stream_id += 1
        # After a protocol switch, replay any bytes h11 buffered past the
        # 101 response exactly once.
        if h11.SWITCHED_PROTOCOL == self._connection.their_state and not self._switched:
            data, closed = self._connection.trailing_data
            if data:
                self._events.append(DataReceived(self._current_stream_id, data))
            self._switched = True

    def reshelve(self, *events: Event) -> None:
        """Put back events into the deque (preserving original order)."""
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        raise NotImplementedError("http1 does not support PING")

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._h2 import HTTP2ProtocolHyperImpl
__all__ = ("HTTP2ProtocolHyperImpl",)

View File

@@ -0,0 +1,312 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from secrets import token_bytes
from typing import Iterator
import jh2.config # type: ignore
import jh2.connection # type: ignore
import jh2.errors # type: ignore
import jh2.events # type: ignore
import jh2.exceptions # type: ignore
import jh2.settings # type: ignore
from ..._stream_matrix import StreamMatrix
from ..._typing import HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
HandshakeCompleted,
HeadersReceived,
StreamResetReceived,
)
from .._protocols import HTTP2Protocol
class _PatchedH2Connection(jh2.connection.H2Connection):  # type: ignore[misc]
    """
    Performance and behavior patches over jh2's ``H2Connection``.

    We already keep track of the open stream count in the wrapping
    :class:`HTTP2ProtocolHyperImpl`, so the expensive recount can be
    delegated to it; we also enable the extended CONNECT protocol and
    soften graceful GOAWAY handling.
    """

    def __init__(
        self,
        config: jh2.config.H2Configuration | None = None,
        observable_impl: HTTP2ProtocolHyperImpl | None = None,
    ) -> None:
        super().__init__(config=config)
        # by default CONNECT is disabled
        # we need it to support natively WebSocket over HTTP/2 for example.
        self.local_settings = jh2.settings.Settings(
            client=True,
            initial_values={
                jh2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                jh2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: self.DEFAULT_MAX_HEADER_LIST_SIZE,
                jh2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
            },
        )
        # Back-reference to the wrapper that owns the stream accounting.
        self._observable_impl = observable_impl

    def _open_streams(self, *args, **kwargs) -> int:  # type: ignore[no-untyped-def]
        # Hot path: use the wrapper's counter instead of recomputing the
        # open stream count from jh2 internals.
        if self._observable_impl is not None:
            return self._observable_impl._open_stream_count
        return super()._open_streams(*args, **kwargs)  # type: ignore[no-any-return]

    def _receive_goaway_frame(self, frame):  # type: ignore[no-untyped-def]
        """
        Receive a GOAWAY frame on the connection.
        We purposely override this method to work around a known bug of jh2.
        """
        events = self.state_machine.process_input(
            jh2.connection.ConnectionInputs.RECV_GOAWAY
        )
        err_code = jh2.errors._error_code_from_int(frame.error_code)
        # GOAWAY allows an
        # endpoint to gracefully stop accepting new streams while still
        # finishing processing of previously established streams.
        # see https://tools.ietf.org/html/rfc7540#section-6.8
        # hyper/h2 does not allow such a thing for now. let's work around this.
        if (
            err_code == 0
            and self._observable_impl is not None
            and self._observable_impl._open_stream_count > 0
        ):
            # Graceful shutdown with streams still in flight: keep the
            # connection usable so they can complete.
            self.state_machine.state = jh2.connection.ConnectionState.CLIENT_OPEN
        # Clear the outbound data buffer: we cannot send further data now.
        self.clear_outbound_data_buffer()
        # Fire an appropriate ConnectionTerminated event.
        new_event = jh2.events.ConnectionTerminated()
        new_event.error_code = err_code
        new_event.last_stream_id = frame.last_stream_id
        new_event.additional_data = (
            frame.additional_data if frame.additional_data else None
        )
        events.append(new_event)
        return [], events
# jh2 event classes that both carry HTTP fields (response headers or
# trailers) — they share a single translation path in _map_events.
HEADER_OR_TRAILER_TYPE_SET = {
    jh2.events.ResponseReceived,
    jh2.events.TrailersReceived,
}
class HTTP2ProtocolHyperImpl(HTTP2Protocol):
    """
    Sans-IO HTTP/2 client protocol backed by the ``jh2`` library.

    Translates jh2 events into the package-level :class:`Event`
    hierarchy, and keeps its own open-stream counter so the patched
    connection class can avoid recomputing it (see
    :class:`_PatchedH2Connection`).
    """

    implementation: str = "h2"

    def __init__(
        self,
        *,
        validate_outbound_headers: bool = False,
        validate_inbound_headers: bool = False,
        normalize_outbound_headers: bool = False,
        normalize_inbound_headers: bool = True,
    ) -> None:
        self._connection: jh2.connection.H2Connection = _PatchedH2Connection(
            jh2.config.H2Configuration(
                client_side=True,
                validate_outbound_headers=validate_outbound_headers,
                normalize_outbound_headers=normalize_outbound_headers,
                validate_inbound_headers=validate_inbound_headers,
                normalize_inbound_headers=normalize_inbound_headers,
            ),
            observable_impl=self,
        )
        # Number of currently open streams (authoritative counter).
        self._open_stream_count: int = 0
        self._connection.initiate_connection()
        # Enlarge the connection-level flow-control window upfront (16 MiB).
        self._connection.increment_flow_control_window(2**24)
        # Inbound events, addressable per stream ID.
        self._events: StreamMatrix = StreamMatrix()
        self._terminated: bool = False
        # Set on graceful GOAWAY: finish in-flight streams, open no new ones.
        self._goaway_to_honor: bool = False
        self._max_stream_count: int = (
            self._connection.remote_settings.max_concurrent_streams
        )
        self._max_frame_size: int = self._connection.remote_settings.max_frame_size

    def max_frame_size(self) -> int:
        # NOTE(review): unlike ``max_stream_count`` this is a plain method,
        # not a @property — confirm callers invoke it as ``obj.max_frame_size()``.
        return self._max_frame_size

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types callers should be prepared to catch."""
        return jh2.exceptions.ProtocolError, jh2.exceptions.H2Error

    def is_available(self) -> bool:
        if self._terminated:
            return False
        # Room for another stream under the peer's concurrency limit?
        return self._max_stream_count > self._open_stream_count

    @property
    def max_stream_count(self) -> int:
        return self._max_stream_count

    def is_idle(self) -> bool:
        return self._terminated is False and self._open_stream_count == 0

    def has_expired(self) -> bool:
        # A connection honoring a graceful GOAWAY must not be reused either.
        return self._terminated or self._goaway_to_honor

    def get_available_stream_id(self) -> int:
        return self._connection.get_next_available_stream_id()  # type: ignore[no-any-return]

    def submit_close(self, error_code: int = 0) -> None:
        # Sends a GOAWAY frame with the given error code.
        self._connection.close_connection(error_code)

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        self._connection.send_headers(stream_id, headers, end_stream)
        # Enlarge the per-stream flow-control window upfront (16 MiB).
        self._connection.increment_flow_control_window(2**24, stream_id=stream_id)
        self._open_stream_count += 1

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        self._connection.send_data(stream_id, data, end_stream)

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        self._connection.reset_stream(stream_id, error_code)

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _map_events(self, h2_events: list[jh2.events.Event]) -> Iterator[Event]:
        """Translate raw jh2 events into package-level events."""
        for e in h2_events:
            ev_type = e.__class__
            if ev_type in HEADER_OR_TRAILER_TYPE_SET:
                end_stream = e.stream_ended is not None
                if end_stream:
                    # Eagerly retire the stream from jh2's bookkeeping.
                    self._open_stream_count -= 1
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield HeadersReceived(e.stream_id, e.headers, end_stream=end_stream)
            elif ev_type is jh2.events.DataReceived:
                end_stream = e.stream_ended is not None
                if end_stream:
                    self._open_stream_count -= 1
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                # Replenish the flow-control window for the bytes consumed.
                self._connection.acknowledge_received_data(
                    e.flow_controlled_length, e.stream_id
                )
                yield DataReceived(e.stream_id, e.data, end_stream=end_stream)
            elif ev_type is jh2.events.InformationalResponseReceived:
                yield EarlyHeadersReceived(
                    e.stream_id,
                    e.headers,
                )
            elif ev_type is jh2.events.StreamReset:
                self._open_stream_count -= 1
                # event StreamEnded may occur before StreamReset
                if e.stream_id in self._connection.streams:
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield StreamResetReceived(e.stream_id, e.error_code)
            elif ev_type is jh2.events.ConnectionTerminated:
                # ConnectionTerminated from h2 means that GOAWAY was received.
                # A server can send GOAWAY for graceful shutdown, where clients
                # do not open new streams, but inflight requests can be completed.
                #
                # Saying "connection was terminated" can be confusing,
                # so we emit an event called "GoawayReceived".
                if e.error_code == 0:
                    self._goaway_to_honor = True
                    yield GoawayReceived(e.last_stream_id, e.error_code)
                else:
                    self._terminated = True
                    yield ConnectionTerminated(e.error_code, None)
            elif ev_type in {
                jh2.events.SettingsAcknowledged,
                jh2.events.RemoteSettingsChanged,
            }:
                # The SETTINGS exchange signals the HTTP/2 handshake is done.
                # NOTE(review): emitted for every SETTINGS ack/update, not just
                # the first — confirm upper layers tolerate duplicates.
                yield HandshakeCompleted(alpn_protocol="h2")

    def connection_lost(self) -> None:
        self._connection_terminated()

    def eof_received(self) -> None:
        self._connection_terminated()

    def bytes_received(self, data: bytes) -> None:
        if not data:
            return
        try:
            h2_events = self._connection.receive_data(data)
        except jh2.exceptions.ProtocolError as e:
            self._connection_terminated(e.error_code, str(e))
        else:
            self._events.extend(self._map_events(h2_events))
        # we want to perpetually mark the connection as "saturated"
        if self._goaway_to_honor:
            self._max_stream_count = self._open_stream_count
        if self._connection.remote_settings.has_update:
            if not self._goaway_to_honor:
                self._max_stream_count = (
                    self._connection.remote_settings.max_concurrent_streams
                )
            self._max_frame_size = self._connection.remote_settings.max_frame_size

    def bytes_to_send(self) -> bytes:
        return self._connection.data_to_send()  # type: ignore[no-any-return]

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> None:
        # Idempotent: only the first termination is recorded and queued.
        if self._terminated:
            return
        error_code = int(error_code)  # Convert h2 IntEnum to an actual int
        self._terminated = True
        self._events.append(ConnectionTerminated(error_code, message))

    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        """Return whether sending ``amt`` bytes would exceed the remote window."""
        flow_remaining_bytes: int = self._connection.local_flow_control_window(
            stream_id
        )
        if amt is None:
            return flow_remaining_bytes == 0
        return amt > flow_remaining_bytes

    def reshelve(self, *events: Event) -> None:
        """Put back events into the deque (preserving original order)."""
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        # PING payload must be 8 opaque bytes per RFC 7540 §6.7.
        self._connection.ping(token_bytes(8))

View File

@@ -0,0 +1,21 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from ._qh3 import HTTP3ProtocolAioQuicImpl
__all__ = ("HTTP3ProtocolAioQuicImpl",)

View File

@@ -0,0 +1,592 @@
# Copyright 2022 Akamai Technologies, Inc
# Largely rewritten in 2023 for urllib3-future
# Copyright 2024 Ahmed Tahri
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import datetime
import ssl
import typing
from collections import deque
from os import environ
from random import randint
from time import time as monotonic
from typing import Any, Iterable, Sequence
if typing.TYPE_CHECKING:
from typing_extensions import Literal
from qh3 import (
CipherSuite,
H3Connection,
H3Error,
ProtocolError,
QuicConfiguration,
QuicConnection,
QuicConnectionError,
QuicFileLogger,
SessionTicket,
h3_events,
quic_events,
)
from qh3.h3.connection import FrameType
from qh3.quic.connection import QuicConnectionState
from ..._configuration import QuicTLSConfig
from ..._stream_matrix import StreamMatrix
from ..._typing import AddressType, HeadersType
from ...events import (
ConnectionTerminated,
DataReceived,
EarlyHeadersReceived,
Event,
GoawayReceived,
)
from ...events import HandshakeCompleted as _HandshakeCompleted
from ...events import HeadersReceived, StreamResetReceived
from .._protocols import HTTP3Protocol
# quic_events classes the HTTP/3 adapter cares about at the QUIC layer;
# presumably other QUIC events are filtered out before mapping — confirm
# against the event-mapping code further down.
QUIC_RELEVANT_EVENT_TYPES = {
    quic_events.HandshakeCompleted,
    quic_events.ConnectionTerminated,
    quic_events.StreamReset,
}
class HTTP3ProtocolAioQuicImpl(HTTP3Protocol):
implementation: str = "qh3"
    def __init__(
        self,
        *,
        remote_address: AddressType,
        server_name: str,
        tls_config: QuicTLSConfig,
    ) -> None:
        """
        Build the QUIC configuration from ``tls_config`` and create the
        underlying qh3 :class:`QuicConnection`.

        :param remote_address: peer (host, port) the datagrams target
        :param server_name: SNI / certificate hostname
        :param tls_config: TLS material and verification policy
        """
        # Honor the de-facto standard env vars for TLS key logging / qlog.
        keylogfile_path: str | None = environ.get("SSLKEYLOGFILE", None)
        qlogdir_path: str | None = environ.get("QUICLOGDIR", None)
        self._configuration: QuicConfiguration = QuicConfiguration(
            is_client=True,
            verify_mode=ssl.CERT_NONE if tls_config.insecure else ssl.CERT_REQUIRED,
            cafile=tls_config.cafile,
            capath=tls_config.capath,
            cadata=tls_config.cadata,
            alpn_protocols=["h3"],
            session_ticket=tls_config.session_ticket,
            server_name=server_name,
            hostname_checks_common_name=tls_config.cert_use_common_name,
            assert_fingerprint=tls_config.cert_fingerprint,
            verify_hostname=tls_config.verify_hostname,
            secrets_log_file=open(keylogfile_path, "w") if keylogfile_path else None,  # type: ignore[arg-type]
            quic_logger=QuicFileLogger(qlogdir_path) if qlogdir_path else None,
            idle_timeout=tls_config.idle_timeout,
            max_data=2**24,
            max_stream_data=2**24,
        )
        if tls_config.ciphers:
            # Map OpenSSL-style "TLS_*" cipher names onto qh3 CipherSuite.
            available_ciphers = {c.name: c for c in CipherSuite}
            chosen_ciphers: list[CipherSuite] = []
            for cipher in tls_config.ciphers:
                if "name" in cipher and isinstance(cipher["name"], str):
                    chosen_ciphers.append(
                        available_ciphers[cipher["name"].replace("TLS_", "")]
                    )
            if len(chosen_ciphers) == 0:
                raise ValueError(
                    f"Unable to find a compatible cipher in '{tls_config.ciphers}' to establish a QUIC connection. "
                    f"QUIC support one of '{['TLS_' + e for e in available_ciphers.keys()]}' only."
                )
            self._configuration.cipher_suites = chosen_ciphers
        if tls_config.certfile:
            # Client certificate (mTLS) support.
            self._configuration.load_cert_chain(
                tls_config.certfile,
                tls_config.keyfile,
                tls_config.keypassword,
            )
        self._quic: QuicConnection = QuicConnection(configuration=self._configuration)
        # Connection IDs seen so far; used to route incoming UDP packets.
        self._connection_ids: set[bytes] = set()
        self._remote_address = remote_address
        # Inbound events, addressable per stream ID.
        self._events: StreamMatrix = StreamMatrix()
        # Outbound datagrams awaiting flush.
        self._packets: deque[bytes] = deque()
        # HTTP/3 layer; created after the QUIC handshake completes.
        self._http: H3Connection | None = None
        self._terminated: bool = False
        self._data_in_flight: bool = False
        self._open_stream_count: int = 0
        self._total_stream_count: int = 0
        # Set on graceful GOAWAY: finish in-flight streams, open no new ones.
        self._goaway_to_honor: bool = False
        self._max_stream_count: int = (
            100  # safe-default, broadly used. (and set by qh3)
        )
        self._max_frame_size: int | None = None
    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types callers should be prepared to catch."""
        return ProtocolError, H3Error, QuicConnectionError, AssertionError
    @property
    def max_stream_count(self) -> int:
        # Peer's concurrent-stream allowance (default 100 until updated).
        return self._max_stream_count
    def is_available(self) -> bool:
        # Room for another outbound stream under the concurrency limit?
        return (
            self._terminated is False
            and self._max_stream_count > self._quic.open_outbound_streams
        )
    def is_idle(self) -> bool:
        # Alive but with no streams in flight.
        return self._terminated is False and self._open_stream_count == 0
    def has_expired(self) -> bool:
        """
        Return whether this connection is closed or should be closed.

        Unlike the TCP implementations this also pumps qh3's timers, since
        QUIC may expire (idle timeout, retransmission give-up) without any
        inbound byte triggering the state change.
        """
        if not self._terminated and not self._goaway_to_honor:
            # NOTE(review): ``monotonic`` here is an alias of ``time.time``
            # (see module imports) — wall-clock, not a true monotonic clock.
            now = monotonic()
            self._quic.handle_timer(now)
            # Timer handling may produce packets (e.g. CONNECTION_CLOSE).
            self._packets.extend(
                map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
            )
            if self._quic._state in {
                QuicConnectionState.CLOSING,
                QuicConnectionState.TERMINATED,
            }:
                self._terminated = True
            if (
                hasattr(self._quic, "_close_event")
                and self._quic._close_event is not None
            ):
                # Surface the close reason to the upper layer as an event.
                self._events.extend(self._map_quic_event(self._quic._close_event))
                self._terminated = True
        return self._terminated or self._goaway_to_honor
    @property
    def session_ticket(self) -> SessionTicket | None:
        # TLS session ticket for 0-RTT / resumption, once issued.
        return self._quic.tls.session_ticket if self._quic and self._quic.tls else None
def get_available_stream_id(self) -> int:
    """Ask the QUIC layer for the next usable outbound stream id."""
    return self._quic.get_next_available_stream_id()
def submit_close(self, error_code: int = 0) -> None:
    """Queue a CONNECTION_CLOSE frame for this connection.

    RFC 9000 defines two close frame types: 0x1c signals a QUIC-layer
    error (or a clean close with the NO_ERROR code), while 0x1d signals
    an error raised by the application protocol running on top of QUIC.
    """
    if error_code:
        frame_type = 0x1D  # application-level error
    else:
        frame_type = 0x1C  # transport-level close / NO_ERROR
    self._quic.close(error_code=error_code, frame_type=frame_type)
def submit_headers(
    self, stream_id: int, headers: HeadersType, end_stream: bool = False
) -> None:
    """Send an HTTP/3 HEADERS frame, opening a brand-new request stream."""
    assert self._http is not None
    # every HEADERS submission corresponds to one new stream
    self._open_stream_count += 1
    self._total_stream_count += 1
    self._http.send_headers(stream_id, list(headers), end_stream)
def submit_data(
    self, stream_id: int, data: bytes, end_stream: bool = False
) -> None:
    """Send an HTTP/3 DATA frame on an already-opened stream."""
    assert self._http is not None
    self._http.send_data(stream_id, data, end_stream)
    if not end_stream:
        # more body is coming: remember that we are awaiting flow-control room
        self._data_in_flight = True
def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
    """Abruptly terminate a single stream with a RESET_STREAM frame."""
    self._quic.reset_stream(stream_id, error_code)
def next_event(self, stream_id: int | None = None) -> Event | None:
    """Pop the oldest pending event, optionally scoped to a single stream."""
    return self._events.popleft(stream_id=stream_id)
def has_pending_event(
    self,
    *,
    stream_id: int | None = None,
    excl_event: tuple[type[Event], ...] | None = None,
) -> bool:
    """Whether at least one event is queued (optionally for one stream only)."""
    return self._events.has(stream_id=stream_id, excl_event=excl_event)
@property
def connection_ids(self) -> Sequence[bytes]:
    """Snapshot of the QUIC connection ids currently associated with us."""
    return [*self._connection_ids]
def connection_lost(self) -> None:
    """Mark the connection dead after the underlying transport vanished."""
    self._events.append(ConnectionTerminated())
    self._terminated = True
def bytes_received(self, data: bytes) -> None:
    """Feed one incoming UDP datagram to the QUIC engine, then refresh our
    cached view of the remote flow-control limits."""
    self._quic.receive_datagram(data, self._remote_address, now=monotonic())
    self._fetch_events()
    # Receiving anything from the peer means our previously sent body data
    # has been (at least partially) acknowledged/processed.
    if self._data_in_flight:
        self._data_in_flight = False
    # we want to perpetually mark the connection as "saturated"
    if self._goaway_to_honor:
        self._max_stream_count = self._open_stream_count
    else:
        # This section may confuse beginners
        # See RFC 9000 -> 19.11. MAX_STREAMS Frames
        # footer extract:
        # Note that these frames (and the corresponding transport parameters)
        # do not describe the number of streams that can be opened
        # concurrently. The limit includes streams that have been closed as
        # well as those that are open.
        #
        # so, finding that remote_max_streams_bidi is increasing constantly is normal.
        new_stream_limit = (
            self._quic._remote_max_streams_bidi - self._total_stream_count
        )
        if (
            new_stream_limit
            and new_stream_limit != self._max_stream_count
            and new_stream_limit > 0
        ):
            self._max_stream_count = new_stream_limit
        # Track the per-stream data window advertised by the peer; used as
        # an upper bound for a single DATA frame payload.
        if (
            self._quic._remote_max_stream_data_bidi_remote
            and self._quic._remote_max_stream_data_bidi_remote
            != self._max_frame_size
        ):
            self._max_frame_size = self._quic._remote_max_stream_data_bidi_remote
def bytes_to_send(self) -> bytes:
    """Pop a single outgoing UDP datagram; empty bytes when nothing is pending."""
    if not self._packets:
        now = monotonic()
        # First call ever: start the QUIC handshake and attach HTTP/3 on top.
        if self._http is None:
            self._quic.connect(self._remote_address, now=now)
            self._http = H3Connection(self._quic)
        # the QUIC state machine returns datagrams (addr, packet)
        # the client never have to worry about the destination
        # unless server yield a preferred address?
        self._packets.extend(
            map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
        )
        if not self._packets:
            return b""
    # it is absolutely crucial to return one at a time
    # because UDP don't support sending more than
    # MTU (to be more precise, lowest MTU in the network path from A (you) to B (server))
    return self._packets.popleft()
def _fetch_events(self) -> None:
    """Drain qh3's event queue and translate everything into internal events."""
    assert self._http is not None
    for quic_event in iter(self._quic.next_event, None):
        self._events.extend(self._map_quic_event(quic_event))
        # The H3 layer consumes each raw QUIC event and may in turn emit
        # HTTP-level events (headers, data, ...).
        for h3_event in self._http.handle_event(quic_event):
            self._events.extend(self._map_h3_event(h3_event))
    # A close may be recorded on the connection without surfacing as an event;
    # map it explicitly when qh3 exposes it.
    if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
        self._events.extend(self._map_quic_event(self._quic._close_event))
def _map_quic_event(self, quic_event: quic_events.QuicEvent) -> Iterable[Event]:
    """Translate a qh3 QUIC event into zero or more internal events."""
    ev_type = quic_event.__class__
    # fastest path execution, most of the time we don't have those
    # 3 event types.
    if ev_type not in QUIC_RELEVANT_EVENT_TYPES:
        return
    if ev_type is quic_events.HandshakeCompleted:
        yield _HandshakeCompleted(quic_event.alpn_protocol)  # type: ignore[attr-defined]
    elif ev_type is quic_events.ConnectionTerminated:
        if quic_event.frame_type == FrameType.GOAWAY.value:  # type: ignore[attr-defined]
            # Graceful shutdown: no new streams may be opened; report the
            # highest stream id we have events for as the last admitted one.
            self._goaway_to_honor = True
            stream_list: list[int] = [
                e for e in self._events._matrix.keys() if e is not None
            ]
            yield GoawayReceived(stream_list[-1], quic_event.error_code)  # type: ignore[attr-defined]
        else:
            # Hard termination of the whole connection.
            self._terminated = True
            yield ConnectionTerminated(
                quic_event.error_code,  # type: ignore[attr-defined]
                quic_event.reason_phrase,  # type: ignore[attr-defined]
            )
    elif ev_type is quic_events.StreamReset:
        # A reset closes the stream; keep the open-stream accounting in sync.
        self._open_stream_count -= 1
        yield StreamResetReceived(quic_event.stream_id, quic_event.error_code)  # type: ignore[attr-defined]
def _map_h3_event(self, h3_event: h3_events.H3Event) -> Iterable[Event]:
    """Translate a qh3 HTTP/3 event into zero or more internal events."""
    ev_type = h3_event.__class__
    if ev_type is h3_events.HeadersReceived:
        # stream_ended means the response fit entirely in this frame.
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield HeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
            h3_event.stream_ended,  # type: ignore[attr-defined]
        )
    elif ev_type is h3_events.DataReceived:
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield DataReceived(h3_event.stream_id, h3_event.data, h3_event.stream_ended)  # type: ignore[attr-defined]
    elif ev_type is h3_events.InformationalHeadersReceived:
        # Informational (1xx) responses never end the stream.
        yield EarlyHeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
        )
def should_wait_remote_flow_control(
    self, stream_id: int, amt: int | None = None
) -> bool | None:
    """Tell whether the caller should pause before submitting more body bytes."""
    return self._data_in_flight
@typing.overload
def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...
@typing.overload
def getissuercert(
    self, *, binary_form: Literal[False] = ...
) -> dict[str, Any] | None: ...
def getissuercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any] | None:
    """Return the first intermediate (issuer) certificate sent by the peer.

    :param binary_form: when True, return the certificate's raw bytes
        instead of a ``ssl``-style dictionary.
    :returns: ``None`` when the peer did not provide any issuer certificate.
    :raises ValueError: if called before the TLS handshake completed.
    """
    if self._quic.get_peercert() is None:
        raise ValueError("TLS handshake has not been done yet")
    # fetch the intermediates once (the original queried them twice) and
    # drop the dead datetime computation it carried.
    issuer_certs = self._quic.get_issuercerts()
    if not issuer_certs:
        return None
    x509_certificate = issuer_certs[0]
    if binary_form:
        return x509_certificate.public_bytes()
    issuer_info = {
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
    }
    # Map RFC 4514 attribute short names to the long names CPython's ssl
    # module exposes in getpeercert() dictionaries.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue  # unknown attributes are silently skipped
        issuer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        issuer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    return issuer_info
@typing.overload
def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...
@typing.overload
def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...
def getpeercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any]:
    """Return the peer (leaf) certificate, mimicking ``ssl.SSLSocket.getpeercert``.

    :param binary_form: when True, return the certificate's raw bytes
        instead of a dictionary.
    :raises ValueError: if called before the TLS handshake completed.
    """
    x509_certificate = self._quic.get_peercert()
    if x509_certificate is None:
        raise ValueError("TLS handshake has not been done yet")
    if binary_form:
        return x509_certificate.public_bytes()
    peer_info = {
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "subjectAltName": [],
        "OCSP": [],
        "caIssuers": [],
        "crlDistributionPoints": [],
    }
    # Map RFC 4514 attribute short names to the long names CPython's ssl
    # module exposes in getpeercert() dictionaries.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue  # unknown attributes are silently skipped
        peer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        peer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )
    # SANs are rendered by qh3 as e.g. "DNS(example.com)" or
    # "IP Address(7f:00:00:01)"; extract the value between parentheses.
    for alt_name in x509_certificate.get_subject_alt_names():
        decoded_alt_name = alt_name.decode()
        in_parenthesis = decoded_alt_name[
            decoded_alt_name.index("(") + 1 : decoded_alt_name.index(")")
        ]
        if decoded_alt_name.startswith("DNS"):
            peer_info["subjectAltName"].append(("DNS", in_parenthesis))  # type: ignore[attr-defined]
        else:
            from ....resolver.utils import inet4_ntoa, inet6_ntoa

            # colon-separated hex octets: 4 bytes -> 11 chars means IPv4
            if len(in_parenthesis) == 11:
                ip_address_decoded = inet4_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            else:
                ip_address_decoded = inet6_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            peer_info["subjectAltName"].append(("IP Address", ip_address_decoded))  # type: ignore[attr-defined]
    # note: "OCSP"/"caIssuers"/"crlDistributionPoints" are already
    # initialized to empty lists above; no need to reassign them here
    # (the original code redundantly reset each before filling it).
    for endpoint in x509_certificate.get_ocsp_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["OCSP"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    for endpoint in x509_certificate.get_issuer_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["caIssuers"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    for endpoint in x509_certificate.get_crl_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["crlDistributionPoints"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )
    # Convert lists to tuples (as the stdlib does) and drop empty entries.
    pop_keys = []
    for k in peer_info:
        if isinstance(peer_info[k], list):
            peer_info[k] = tuple(peer_info[k])  # type: ignore[arg-type]
        if not peer_info[k]:
            pop_keys.append(k)
    for k in pop_keys:
        peer_info.pop(k)
    return peer_info
def cipher(self) -> str | None:
    """Negotiated TLS cipher suite, using the stdlib "TLS_*" naming scheme.

    :raises ValueError: if called before the TLS handshake completed.
    """
    negotiated = self._quic.get_cipher()
    if negotiated is None:
        raise ValueError("TLS handshake has not been done yet")
    return "TLS_" + negotiated.name
def reshelve(self, *events: Event) -> None:
    """Push events back at the front of the queue, preserving their order."""
    for ev in events[::-1]:
        self._events.appendleft(ev)
def ping(self) -> None:
    """Emit a QUIC PING frame (liveness probe) tagged with a random uid."""
    uid = randint(0, 65535)
    self._quic.send_ping(uid)
def max_frame_size(self) -> int:
    """Largest DATA payload the remote per-stream window currently allows.

    :raises NotImplementedError: when the peer has not advertised a limit yet.
    """
    if self._max_frame_size is None:
        raise NotImplementedError
    return self._max_frame_size

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
import typing
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
from ._ctypes import load_cert_chain as _ctypes_load_cert_chain
from ._shm import load_cert_chain as _shm_load_cert_chain
# Ordered list of strategies that can load an in-memory certificate chain
# into an ``ssl.SSLContext``: first the ctypes/OpenSSL direct approach,
# then the shared-memory (memfd/shm) pseudo-file fallback. Each callable
# takes (ctx, certdata, keydata, password) and raises UnsupportedOperation
# when it cannot operate on the current platform/interpreter.
SUPPORTED_METHODS: list[
    typing.Callable[
        [
            ssl.SSLContext,
            bytes | str,
            bytes | str,
            bytes | str | typing.Callable[[], str | bytes] | None,
        ],
        None,
    ]
] = [
    _ctypes_load_cert_chain,
    _shm_load_cert_chain,
]
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    first_failure: UnsupportedOperation | None = None

    # Try each strategy in order; the first one that succeeds wins.
    for strategy in SUPPORTED_METHODS:
        try:
            strategy(ctx, certdata, keydata, password)
        except UnsupportedOperation as exc:
            if first_failure is None:
                first_failure = exc
        else:
            return

    # Every strategy refused: surface the first failure for diagnosis.
    if first_failure is not None:
        raise first_failure
    raise UnsupportedOperation("unable to initialize mTLS using in-memory cert and key")
__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,376 @@
from __future__ import annotations
import ctypes
import os
import sys
import typing
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
class _OpenSSL:
    """Access hazardous material from CPython OpenSSL (or compatible SSL) implementation."""

    def __init__(self) -> None:
        # Locate the libssl/libcrypto shared objects loaded by CPython's
        # ``_ssl`` extension and bind the handful of C symbols we need.
        # Raises UnsupportedOperation whenever any prerequisite is missing.
        import platform

        if platform.python_implementation() != "CPython":
            raise UnsupportedOperation("Only CPython is supported")

        import ssl

        self._name = ssl.OPENSSL_VERSION
        self.ssl = ssl

        # bug seen in Windows + CPython < 3.11
        # where CPython official API for options
        # cast OpenSSL get_options to SIGNED long
        # where we want UNSIGNED long.
        _ssl_options_signed_long_bug = False

        if not hasattr(ssl, "_ssl"):
            raise UnsupportedOperation(
                "Unsupported interpreter due to missing private ssl module"
            )

        if platform.system() == "Windows":
            # possible search locations
            candidates = {
                os.path.dirname(sys.executable),
                os.path.join(sys.prefix, "DLLs"),
                sys.prefix,
            }
            if hasattr(ssl._ssl, "__file__"):
                candidates.add(os.path.dirname(ssl._ssl.__file__))
            _ssl_options_signed_long_bug = sys.version_info < (3, 11)
            ssl_potential_match = None
            crypto_potential_match = None
            for d in candidates:
                if not os.path.exists(d):
                    continue
                for filename in os.listdir(d):
                    if ssl_potential_match is None:
                        if filename.startswith("libssl") and filename.endswith(".dll"):
                            ssl_potential_match = os.path.join(d, filename)
                    if crypto_potential_match is None:
                        if filename.startswith("libcrypto") and filename.endswith(
                            ".dll"
                        ):
                            crypto_potential_match = os.path.join(d, filename)
                if crypto_potential_match and ssl_potential_match:
                    break
            if not ssl_potential_match or not crypto_potential_match:
                raise UnsupportedOperation(
                    "Could not locate OpenSSL DLLs next to Python; "
                    "check your /DLLs folder or your PATH."
                )
            self._ssl = ctypes.CDLL(ssl_potential_match)
            self._crypto = ctypes.CDLL(crypto_potential_match)
        else:
            # that's the most common path
            # ssl built in module already loaded both crypto and ssl
            # symbols.
            if hasattr(ssl._ssl, "__file__"):
                self._ssl = ctypes.CDLL(ssl._ssl.__file__)
            else:
                # _ssl is statically linked into the interpreter
                # (e.g. python-build-standalone via uv). OpenSSL symbols
                # are in the main process image; ctypes.CDLL(None) exposes them.
                # see https://github.com/jawah/urllib3.future/issues/325 for more
                # details.
                self._ssl = ctypes.CDLL(None)
            self._crypto = self._ssl
        # we want to ensure a minimal set of symbols
        # are present. CPython should have at least:
        for required_symbol in [
            "SSL_CTX_use_certificate",
            "SSL_CTX_check_private_key",
            "SSL_CTX_use_PrivateKey",
        ]:
            if not hasattr(self._ssl, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libssl) {required_symbol} is not present."
                )
        for required_symbol in [
            "BIO_free",
            "BIO_new_mem_buf",
            "PEM_read_bio_X509",
            "PEM_read_bio_PrivateKey",
            "ERR_get_error",
            "ERR_error_string",
        ]:
            if not hasattr(self._crypto, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libcrypto) {required_symbol} is not present."
                )
        # Declare argtypes/restype for each bound symbol so ctypes marshals
        # pointers correctly on every platform (64-bit especially).
        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_certificate = self._ssl.SSL_CTX_use_certificate
        self.SSL_CTX_use_certificate.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_certificate.restype = ctypes.c_int
        self.SSL_CTX_check_private_key = self._ssl.SSL_CTX_check_private_key
        self.SSL_CTX_check_private_key.argtypes = [ctypes.c_void_p]
        self.SSL_CTX_check_private_key.restype = ctypes.c_int
        # https://docs.openssl.org/3.0/man3/BIO_new/
        self.BIO_free = self._crypto.BIO_free
        self.BIO_free.argtypes = [ctypes.c_void_p]
        self.BIO_free.restype = None
        self.BIO_new_mem_buf = self._crypto.BIO_new_mem_buf
        self.BIO_new_mem_buf.argtypes = [ctypes.c_void_p, ctypes.c_int]
        self.BIO_new_mem_buf.restype = ctypes.c_void_p
        # https://docs.openssl.org/3.0/man3/PEM_read_bio_PrivateKey/
        self.PEM_read_bio_X509 = self._crypto.PEM_read_bio_X509
        self.PEM_read_bio_X509.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_X509.restype = ctypes.c_void_p
        self.PEM_read_bio_PrivateKey = self._crypto.PEM_read_bio_PrivateKey
        self.PEM_read_bio_PrivateKey.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_PrivateKey.restype = ctypes.c_void_p
        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_PrivateKey = self._ssl.SSL_CTX_use_PrivateKey
        self.SSL_CTX_use_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_PrivateKey.restype = ctypes.c_int
        self.ERR_get_error = self._crypto.ERR_get_error
        self.ERR_get_error.argtypes = []
        self.ERR_get_error.restype = ctypes.c_ulong
        self.ERR_error_string = self._crypto.ERR_error_string
        self.ERR_error_string.argtypes = [ctypes.c_ulong, ctypes.c_char_p]
        self.ERR_error_string.restype = ctypes.c_char_p
        if hasattr(self._ssl, "SSL_CTX_get_options"):
            self.SSL_CTX_get_options = self._ssl.SSL_CTX_get_options
            self.SSL_CTX_get_options.argtypes = [ctypes.c_void_p]
            self.SSL_CTX_get_options.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )  # OpenSSL's options are long
        elif hasattr(self._ssl, "SSL_CTX_ctrl"):
            # some old build inline SSL_CTX_get_options (mere C define)
            # define SSL_CTX_get_options(ctx) SSL_CTX_ctrl((ctx),SSL_CTRL_OPTIONS,0,NULL)
            # define SSL_CTRL_OPTIONS 32
            self.SSL_CTX_ctrl = self._ssl.SSL_CTX_ctrl
            self.SSL_CTX_ctrl.argtypes = [
                ctypes.c_void_p,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_void_p,
            ]
            self.SSL_CTX_ctrl.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )
            self.SSL_CTX_get_options = lambda ctx: self.SSL_CTX_ctrl(  # type: ignore[assignment]
                ctx, 32, 0, None
            )
        else:
            raise UnsupportedOperation()

    def pull_error(self) -> typing.NoReturn:
        """Raise ``ssl.SSLError`` carrying OpenSSL's latest queued error string."""
        raise self.ssl.SSLError(
            self.ERR_error_string(
                self.ERR_get_error(), ctypes.create_string_buffer(256)
            ).decode()
        )
# True on free-threaded (no-GIL) CPython builds (PEP 703).
_IS_GIL_DISABLED = hasattr(sys, "_is_gil_enabled") and sys._is_gil_enabled() is False
_IS_LINUX = sys.platform == "linux"
# Number of extra pointer-sized header fields before ``ssl_ctx`` on
# free-threaded builds; presumably platform-ABI dependent — TODO confirm
# against Modules/_ssl.c for each supported CPython release.
_FT_HEAD_ADDITIONAL_OFFSET = 1 if _IS_LINUX else 2
_head_extra_fields = []
if sys.flags.debug:
    # In debug builds (_POSIX_C_SOURCE or Py_DEBUG is defined), PyObject_HEAD
    # is preceded by _PyObject_HEAD_EXTRA, which typically consists of
    # two pointers (_ob_next, _ob_prev).
    _head_extra_fields = [("_ob_next", ctypes.c_void_p), ("_ob_prev", ctypes.c_void_p)]
# Define the PySSLContext C structure using ctypes.
# This definition assumes that 'SSL_CTX *ctx' is the first member
# immediately following PyObject_HEAD. This has been observed to be
# the case in various CPython versions (e.g., 3.7 through 3.14 so far).
#
# CPython's Modules/_ssl.c (simplified):
# typedef struct {
#     PyObject_HEAD        // Expands to _PyObject_HEAD_EXTRA (if debug) + ob_refcnt + ob_type
#     SSL_CTX *ctx;
#     // ... other members ...
# } PySSLContextObject;
#
class PySSLContextStruct(ctypes.Structure):
    # Field layout mirrors PyObject_HEAD (plus debug/free-threaded extras)
    # so that ``.ssl_ctx`` lands exactly on the C ``ctx`` pointer.
    _fields_ = (
        _head_extra_fields  # type: ignore[assignment]
        + [
            ("ob_refcnt", ctypes.c_ssize_t),  # Py_ssize_t ob_refcnt;
            ("ob_type", ctypes.c_void_p),  # PyTypeObject *ob_type;
        ]
        + (
            [(f"_ob_ft{i}", ctypes.c_void_p) for i in range(_FT_HEAD_ADDITIONAL_OFFSET)]
            if _IS_GIL_DISABLED
            else []
        )
        + [
            ("ssl_ctx", ctypes.c_void_p),  # SSL_CTX *ctx; (this is the pointer we want)
            # If there were other C members between ob_type and ssl_ctx,
            # they would need to be defined here with their correct types and padding.
        ]
    )
def _split_client_cert(data: bytes) -> list[bytes]:
line_ending = b"\n" if b"-----\r\n" not in data else b"\r\n"
boundary = b"-----END CERTIFICATE-----" + line_ending
certificates = []
for chunk in data.split(boundary):
if chunk:
start_marker = chunk.find(b"-----BEGIN CERTIFICATE-----" + line_ending)
if start_marker == -1:
break
pem_reconstructed = b"".join([chunk[start_marker:], boundary])
certificates.append(pem_reconstructed)
return certificates
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    lib = _OpenSSL()
    # Get the memory address of the Python ssl.SSLContext object.
    # id() returns the address of the PyObject.
    addr = id(ctx)
    # Cast this memory address to a pointer to our defined PySSLContextStruct.
    ptr_to_pysslcontext_struct = ctypes.cast(addr, ctypes.POINTER(PySSLContextStruct))
    # Access the 'ssl_ctx' field from the structure. This field holds the
    # actual SSL_CTX* C pointer value.
    ssl_ctx_address = ptr_to_pysslcontext_struct.contents.ssl_ctx
    # We want to ensure we got the right pointer address
    # the safest way to achieve that is to retrieve options
    # and compare it with the official ctx property.
    if lib.SSL_CTX_get_options is not None:
        bypass_options = lib.SSL_CTX_get_options(ssl_ctx_address)
        expected_options = int(ctx.options)
        if bypass_options != expected_options:
            raise UnsupportedOperation(
                f"CPython internal SSL_CTX changed! Cannot pursue safely. Expected = {expected_options:x} Actual = {bypass_options:x}"
            )
    # normalize inputs
    if isinstance(certdata, str):
        certdata = certdata.encode()
    if isinstance(keydata, str):
        keydata = keydata.encode()
    client_chain = _split_client_cert(certdata)
    leaf_certificate = client_chain[0]
    # Use a BIO to read the client certificate
    # only the leaf certificate is supported here.
    cert_bio = lib.BIO_new_mem_buf(leaf_certificate, len(leaf_certificate))
    if not cert_bio:
        raise MemoryError("Unable to allocate memory to load the client certificate")
    # Use a BIO to load the key in-memory
    key_bio = lib.BIO_new_mem_buf(keydata, len(keydata))
    if not key_bio:
        raise MemoryError("Unable to allocate memory to load the client key")
    # prepare the password
    if callable(password):
        password = password()
    if isinstance(password, str):
        password = password.encode()
    assert password is None or isinstance(password, bytes)
    # the allocated X509 obj MUST NOT be freed by ourselves
    # OpenSSL internals will free it once not needed.
    cert = lib.PEM_read_bio_X509(cert_bio, None, None, None)
    # we do own the BIO, once the X509 leaf is instantiated, no need
    # to keep it afterward.
    lib.BIO_free(cert_bio)
    if not cert:
        lib.pull_error()
    pkey = lib.PEM_read_bio_PrivateKey(key_bio, None, None, password)
    lib.BIO_free(key_bio)
    if not pkey:
        lib.pull_error()
    # Hand certificate and key over to the SSL_CTX and verify they match.
    if lib.SSL_CTX_use_certificate(ssl_ctx_address, cert) != 1:
        lib.pull_error()
    if lib.SSL_CTX_use_PrivateKey(ssl_ctx_address, pkey) != 1:
        lib.pull_error()
    if lib.SSL_CTX_check_private_key(ssl_ctx_address) != 1:
        lib.pull_error()
    # Unfortunately, most of the time
    # SSL_CTX_add_extra_chain_cert is unavailable
    # in the final CPython build.
    # According to OpenSSL latest docs: "The engine
    # will attempt to build the required chain for the CA store"
    # It's not going to be used as a trust anchor! (i.e. not self-signed)
    # "If no chain is specified, the library will try to complete the
    # chain from the available CA certificates in the trusted
    # CA storage, see SSL_CTX_load_verify_locations(3)."
    # see: https://docs.openssl.org/master/man3/SSL_CTX_add_extra_chain_cert/#notes
    if len(client_chain) > 1:
        # hand the intermediates over via the trust store as a workaround
        ctx.load_verify_locations(cadata=(b"\n".join(client_chain[1:])).decode())
__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import os
import secrets
import stat
import sys
import typing
import warnings
from hashlib import sha256
from io import UnsupportedOperation
if typing.TYPE_CHECKING:
import ssl
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: str | bytes,
    keydata: str | bytes | None = None,
    password: typing.Callable[[], str | bytes] | str | bytes | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
    Only supported on Linux, FreeBSD, and OpenBSD.
    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    if (
        sys.platform != "linux"
        and sys.platform.startswith("freebsd") is False
        and sys.platform.startswith("openbsd") is False
    ):
        raise UnsupportedOperation(
            f"Unable to provide support for in-memory client certificate: Unsupported platform {sys.platform}"
        )
    # Random, unguessable name for the ephemeral in-memory file.
    unique_name: str = f"{sha256(secrets.token_bytes(32)).hexdigest()}.pem"
    if isinstance(certdata, bytes):
        certdata = certdata.decode("ascii")
    if keydata is not None:
        if isinstance(keydata, bytes):
            keydata = keydata.decode("ascii")
    if hasattr(os, "memfd_create"):
        # Preferred path: an anonymous memory-backed fd (Linux 3.17+).
        fd = os.memfd_create(unique_name, os.MFD_CLOEXEC)
    else:
        # this branch patch is for CPython <3.8 and PyPy 3.7+
        from ctypes import c_int, c_ushort, cdll, create_string_buffer, get_errno, util

        loc = util.find_library("rt") or util.find_library("c")
        if not loc:
            raise UnsupportedOperation(
                "Unable to provide support for in-memory client certificate: libc or librt not found."
            )
        lib = cdll.LoadLibrary(loc)
        _shm_open = lib.shm_open
        # _shm_unlink = lib.shm_unlink
        buf_name = create_string_buffer(unique_name.encode())
        try:
            fd = _shm_open(
                buf_name,
                c_int(os.O_RDWR | os.O_CREAT),
                c_ushort(stat.S_IRUSR | stat.S_IWUSR),
            )
        except SystemError as e:
            raise UnsupportedOperation(
                f"Unable to provide support for in-memory client certificate: {e}"
            )
        if fd == -1:
            raise UnsupportedOperation(
                f"Unable to provide support for in-memory client certificate: {os.strerror(get_errno())}"
            )
    # Linux 3.17+
    path = f"/proc/self/fd/{fd}"
    # Alt-path
    shm_path = f"/dev/shm/{unique_name}"
    if os.path.exists(path) is False:
        if os.path.exists(shm_path):
            path = shm_path
        else:
            os.fdopen(fd).close()
            raise UnsupportedOperation(
                "Unable to provide support for in-memory client certificate: no virtual patch available?"
            )
    # Restrict the pseudo-file to the current user before writing secrets.
    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
    with open(path, "w") as fp:
        fp.write(certdata)
        if keydata:
            fp.write(keydata)
        path = fp.name
    # Let the stdlib consume the pseudo-file as if it were a real one.
    ctx.load_cert_chain(path, password=password)
    # we shall start cleaning remnants
    os.fdopen(fd).close()
    if os.path.exists(shm_path):
        os.unlink(shm_path)
    if os.path.exists(path) or os.path.exists(shm_path):
        warnings.warn(
            "In-memory client certificate: The kernel leaked a file descriptor outside of its expected lifetime.",
            ResourceWarning,
        )
__all__ = ("load_cert_chain",)

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import warnings
warnings.warn(
(
"'urllib3.contrib.pyopenssl' module has been removed in urllib3.future due to incompatibilities "
"with our QUIC integration. While the import proceed without error for your convenience, it is rendered "
"completely ineffective. Were you looking for in-memory client certificate? "
"See https://urllib3future.readthedocs.io/en/latest/advanced-usage.html#in-memory-client-mtls-certificate"
),
category=DeprecationWarning,
stacklevel=2,
)
import OpenSSL.SSL # type: ignore # noqa
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
def inject_into_urllib3() -> None:
    """Kept for BC-purposes."""
    return None
def extract_from_urllib3() -> None:
    """Kept for BC-purposes."""
    return None

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from .factories import ResolverDescription, ResolverFactory
from .protocols import BaseResolver, ManyResolver, ProtocolResolver
__all__ = (
"ResolverFactory",
"ProtocolResolver",
"BaseResolver",
"ManyResolver",
"ResolverDescription",
)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from .factories import AsyncResolverDescription, AsyncResolverFactory
from .protocols import AsyncBaseResolver, AsyncManyResolver
__all__ = (
"AsyncResolverDescription",
"AsyncResolverFactory",
"AsyncBaseResolver",
"AsyncManyResolver",
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,656 @@
from __future__ import annotations
import socket
import typing
from asyncio import as_completed
from base64 import b64encode
from ....._async.connectionpool import AsyncHTTPSConnectionPool
from ....._async.response import AsyncHTTPResponse
from ....._collections import HTTPHeaderDict
from .....backend import ConnectionInfo, HttpVersion
from .....util.url import parse_url
from ...protocols import (
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
from ..protocols import AsyncBaseResolver
class HTTPSResolver(AsyncBaseResolver):
"""
Advanced DNS over HTTPS resolver.
No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
also implemented at Cloudflare.
Support RFC 8484 without JSON. Disabled by default.
"""
implementation = "urllib3"
protocol = ProtocolResolver.DOH
def __init__(
    self,
    server: str | None,
    port: int | None = None,
    *patterns: str,
    **kwargs: typing.Any,
) -> None:
    """Normalize the keyword configuration and build the underlying HTTPS pool.

    :param server: hostname of the DoH endpoint.
    :param port: port of the endpoint; defaults to 443.
    :param patterns: host patterns this resolver is allowed to answer for.
    """
    super().__init__(server, port or 443, *patterns, **kwargs)
    # DoH query path; the Google/Cloudflare-style "/resolve" by default.
    self._path: str = "/resolve"
    if "path" in kwargs:
        if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
            self._path = kwargs["path"]
        kwargs.pop("path")
    # RFC 8484 (binary DNS wire format over HTTPS) is opt-in.
    self._rfc8484: bool = False
    if "rfc8484" in kwargs:
        if kwargs["rfc8484"]:
            self._rfc8484 = True
        kwargs.pop("rfc8484")
    assert self._server is not None
    # "ip:port" string -> (ip, port) tuple expected by the pool.
    if "source_address" in kwargs:
        if isinstance(kwargs["source_address"], str):
            bind_ip, bind_port = kwargs["source_address"].split(":", 1)
            if bind_ip and bind_port.isdigit():
                kwargs["source_address"] = (
                    bind_ip,
                    int(bind_port),
                )
            else:
                raise ValueError("invalid source_address given in parameters")
        else:
            raise ValueError("invalid source_address given in parameters")
    if "proxy" in kwargs:
        kwargs["_proxy"] = parse_url(kwargs["proxy"])
        kwargs.pop("proxy")
    if "maxsize" not in kwargs:
        kwargs["maxsize"] = 10
    # Raw "Key: Value" strings -> HTTPHeaderDict for the proxy headers.
    if "proxy_headers" in kwargs and "_proxy" in kwargs:
        proxy_headers = HTTPHeaderDict()
        if not isinstance(kwargs["proxy_headers"], list):
            kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
        for item in kwargs["proxy_headers"]:
            if ":" not in item:
                raise ValueError("Passed header is invalid in DNS parameters")
            k, v = item.split(":", 1)
            proxy_headers.add(k, v)
        kwargs["_proxy_headers"] = proxy_headers
    # Same conversion for the regular request headers.
    if "headers" in kwargs:
        headers = HTTPHeaderDict()
        if not isinstance(kwargs["headers"], list):
            kwargs["headers"] = [kwargs["headers"]]
        for item in kwargs["headers"]:
            if ":" not in item:
                raise ValueError("Passed header is invalid in DNS parameters")
            k, v = item.split(":", 1)
            headers.add(k, v)
        kwargs["headers"] = headers
    # User-provided strings ("h11"/"h2"/"h3") -> HttpVersion enum members.
    if "disabled_svn" in kwargs:
        if not isinstance(kwargs["disabled_svn"], list):
            kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
        disabled_svn = set()
        for svn in kwargs["disabled_svn"]:
            svn = svn.lower()
            if svn == "h11":
                disabled_svn.add(HttpVersion.h11)
            elif svn == "h2":
                disabled_svn.add(HttpVersion.h2)
            elif svn == "h3":
                disabled_svn.add(HttpVersion.h3)
        kwargs["disabled_svn"] = disabled_svn
    # Optional callback invoked with ConnectionInfo after each connect.
    if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
        self._connection_callback: (
            typing.Callable[[ConnectionInfo], None] | None
        ) = kwargs["on_post_connection"]
        kwargs.pop("on_post_connection")
    else:
        self._connection_callback = None
    self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
async def close(self) -> None: # type: ignore[override]
await self._pool.close()
def is_available(self) -> bool:
return self._pool.pool is not None
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DNS-over-HTTPS, mimicking ``socket.getaddrinfo`` output.

        Literal IPv4/IPv6 hosts short-circuit without any network exchange.
        Depending on ``self._rfc8484``, queries are issued either against the
        JSON API (``Accept: application/dns-json``) or as the RFC 8484 binary
        wire format carried in the base64url ``dns`` query parameter
        (``Accept: application/dns-message``).

        When ``quic_upgrade_via_dns_rr`` is true (and the caller asked for a
        stream socket), an HTTPS (type 65) record is also requested; if it
        advertises ``h3``, equivalent ``SOCK_DGRAM`` entries are added so the
        caller may attempt a QUIC upgrade.

        :raises socket.gaierror: on unresolvable input, non-2xx upstream
            responses, or a non-zero DNS status in the payload.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # Literal addresses: return immediately without contacting the DoH server.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        validate_length_of(host)
        promises = []
        remote_preemptive_quic_rr = False
        # Already a datagram socket: nothing to "upgrade" to, skip the HTTPS RR query.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False
        # Fire one multiplexed request per record type (A / AAAA / HTTPS)
        # and collect the promises; responses are awaited later, concurrently.
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                # RFC 8484: raw DNS message, base64url-encoded without padding.
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        if quic_upgrade_via_dns_rr:
            # HTTPS/SVCB record (type 65) — used only to detect h3 support.
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
        tasks = []
        responses = []
        for promise in promises:
            # already resolved (the pool may hand back a response instead of a promise)
            if isinstance(promise, AsyncHTTPResponse):
                responses.append(promise)
                continue
            tasks.append(self._pool.get_response(promise=promise))
        if tasks:
            # Consume responses in completion order; ordering of records does not matter.
            for waiting_promise_coro in as_completed(tasks):
                responses.append(await waiting_promise_coro)  # type: ignore[arg-type]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )
            if not self._rfc8484:
                # JSON API branch (Google/Cloudflare style "application/dns-json").
                payload = await response.json()
                assert "Status" in payload and isinstance(payload["Status"], int)
                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )
                    # "Comment" may be a list of strings depending on the provider.
                    if isinstance(msg, list):
                        msg = ", ".join(msg)
                    raise socket.gaierror(msg)
                assert "Question" in payload and isinstance(payload["Question"], list)
                if "Answer" not in payload:
                    # NOERROR with an empty answer section: nothing to add.
                    continue
                assert isinstance(payload["Answer"], list)
                for answer in payload["Answer"]:
                    if answer["type"] not in [1, 28, 65]:
                        continue
                    assert "data" in answer
                    assert isinstance(answer["data"], str)
                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]
                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            # Raw RDATA rendering: hex bytes after a length prefix.
                            rr = "".join(rr[2:].split(" ")[2:])
                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""
                            if not raw_record:
                                continue
                            https_record = parse_https_rdata(raw_record)
                            if "h3" not in https_record["alpn"]:
                                continue
                            remote_preemptive_quic_rr = True
                        else:
                            # Textual presentation: "key=value" pairs separated by spaces.
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )
                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue
                            remote_preemptive_quic_rr = True
                            # ipv4hint/ipv6hint provide ready-to-use QUIC targets.
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )
                        # HTTPS RRs never produce A/AAAA-style entries below.
                        continue
                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )
                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )
                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # RFC 8484 branch: the body is a raw DNS message.
                dns_resp = DomainNameServerReturn(await response.data)
                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        # Parsed HTTPS RRs carry a dict payload as last element.
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue
                    assert not isinstance(record[-1], dict)
                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )
                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )
                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        if remote_preemptive_quic_rr:
            # Mirror every stream entry with a datagram (QUIC) counterpart — but
            # only when the caller did not already receive datagram entries.
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at dns.google."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed for this shortcut; ignore any user-given server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # RFC 8484 binary mode lives under /dns-query on dns.google.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at cloudflare-dns.com."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Cloudflare serves both JSON and RFC 8484 under /dns-query.
        kwargs.update({"path": "/dns-query"})
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # AdGuard only speaks RFC 8484 binary messages on /dns-query.
        kwargs.update({"path": "/dns-query", "rfc8484": True})
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at dns.opendns.com."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # OpenDNS exposes RFC 8484 binary messages on /dns-query.
        kwargs.update({"path": "/dns-query", "rfc8484": True})
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at dns11.quad9.net."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Quad9 exposes RFC 8484 binary messages on /dns-query.
        kwargs.update({"path": "/dns-query", "rfc8484": True})
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut pointing at dns.nextdns.io."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
try:
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
QUICResolver = None # type: ignore
AdGuardResolver = None # type: ignore
NextDNSResolver = None # type: ignore
__all__ = (
"QUICResolver",
"AdGuardResolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,557 @@
from __future__ import annotations
import asyncio
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..dou import PlainResolver
from ..system import SystemResolver
# qh3's TLS layer cannot be used under a FIPS-constrained build of the ssl
# module, so importing this module must fail loudly. The package __init__
# catches the ImportError and exposes the DoQ resolvers as None instead.
if IS_FIPS:
    raise ImportError(
        "DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
    )
class QUICResolver(PlainResolver):
    """DNS-over-QUIC resolver backed by the qh3 QUIC implementation.

    Follows the DoQ model (one query per unidirectional-use stream, RFC 1035
    messages with a 2-byte length prefix). A single QUIC connection is lazily
    established on first resolution and shared by concurrent callers.
    """

    protocol = ProtocolResolver.DOQ
    implementation = "qh3"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        # 853 is the default DoQ port.
        super().__init__(server, port or 853, *patterns, **kwargs)
        # qh3 load_default_certs seems off. need to investigate.
        # Workaround: harvest the system trust store via a throwaway stdlib
        # SSLContext and feed it to qh3 as concatenated PEM data.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            try:
                ctx.load_default_certs()
                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                del kwargs["ca_cert_data"]
        # Without any CA material we cannot verify the peer; refuse unless the
        # caller explicitly disabled certificate checks.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )
        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )
        # Optional mutual-TLS client certificate, from file or in-memory data.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        self._quic = QuicConnection(configuration=configuration)
        # Serializes reads from the shared socket across concurrent resolutions.
        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        # _connect_attempt: first caller connects; _handshake_event: others wait.
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._handshake_event: asyncio.Event = asyncio.Event()
        self._terminated: bool = False
        self._should_disconnect: bool = False
        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True
        # Responses read on behalf of another in-flight query, keyed by DNS id.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

    async def close(self) -> None:  # type: ignore[override]
        """Gracefully close the QUIC connection, flushing its final datagrams."""
        if (
            not self._terminated
            and self._socket is not None
            and not self._socket.should_connect()
        ):
            self._quic.close()
            # Drain the CONNECTION_CLOSE (and any pending) datagrams to the peer.
            while True:
                datagrams = self._quic.datagrams_to_send(monotonic())
                if not datagrams:
                    break
                for datagram in datagrams:
                    data, addr = datagram
                    await self._socket.sendall(data)
            self._socket.close()
            await self._socket.wait_for_close()
            self._terminated = True
        # Never-connected resolver: nothing to flush, just mark terminated.
        if self._socket is None or self._socket.should_connect():
            self._terminated = True

    def is_available(self) -> bool:
        """Whether this resolver can still serve queries.

        Also advances qh3's internal timer so an idle-timeout close is
        detected and reflected in ``self._terminated``.
        """
        if self._terminated:
            return False
        if self._socket is None or self._socket.should_connect():
            return True
        self._quic.handle_timer(monotonic())
        # NOTE(review): peeks at qh3 private state to detect a closed connection.
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DNS-over-QUIC, mimicking ``socket.getaddrinfo``.

        Literal IPs short-circuit. The first caller performs the QUIC
        handshake; concurrent callers wait on ``_handshake_event``. One query
        per record type is sent on its own stream, then responses are matched
        back to queries by DNS message id (responses for other callers are
        parked in ``_unconsumed``).
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        validate_length_of(host)
        # First caller establishes the UDP transport + QUIC handshake.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()
            assert self.server is not None
            self._quic.connect((self._server, self._port), monotonic())
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=None,
                socket_kind=self._socket_type,
            )
            await self.__exchange_until(HandshakeCompleted, receive_first=False)
            self._handshake_event.set()
        else:
            await self._handshake_event.wait()
        assert self._socket is not None
        remote_preemptive_quic_rr = False
        # Already a datagram socket: no upgrade detection needed.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False
        tbq = []
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)
        queries = DomainNameServerQuery.bulk(host, *tbq)
        open_streams = []
        # One stream per query; DoQ requires the 2-byte RFC 1035 length prefix.
        for q in queries:
            payload = bytes(q)
            self._pending.append(q)
            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)
            stream_id = self._quic.get_next_available_stream_id()
            self._quic.send_stream_data(stream_id, payload, True)
            open_streams.append(stream_id)
        for dg in self._quic.datagrams_to_send(monotonic()):
            await self._socket.sendall(dg[0])
        responses: list[DomainNameServerReturn] = []
        while len(responses) < len(tbq):
            await self._read_semaphore.acquire()
            # Another coroutine may already have read one of our answers.
            if self._unconsumed:
                dns_resp = None
                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break
                if dns_resp:
                    self._unconsumed.remove(dns_resp)
                    self._pending.remove(query)
                    self._read_semaphore.release()
                    continue
            try:
                events: list[StreamDataReceived] = await self.__exchange_until(  # type: ignore[assignment]
                    StreamDataReceived,
                    receive_first=True,
                    event_type_collectable=(StreamDataReceived,),
                    respect_end_stream_signal=False,
                )
                payload = b"".join([e.data for e in events])
                # Keep reading until the length prefix says the message is whole.
                while rfc1035_should_read(payload):
                    events.extend(
                        await self.__exchange_until(  # type: ignore[arg-type]
                            StreamDataReceived,
                            receive_first=True,
                            event_type_collectable=(StreamDataReceived,),
                            respect_end_stream_signal=False,
                        )
                    )
                    payload = b"".join([e.data for e in events])
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e
            self._read_semaphore.release()
            if not payload:
                continue
            #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
            fragments = rfc1035_unpack(payload)
            for fragment in fragments:
                dns_resp = DomainNameServerReturn(fragment)
                if any(dns_resp.id == _.id for _ in queries):
                    responses.append(dns_resp)
                    query_tbr: DomainNameServerQuery | None = None
                    for query_tbr in self._pending:
                        if query_tbr.id == dns_resp.id:
                            break
                    if query_tbr:
                        self._pending.remove(query_tbr)
                else:
                    # Belongs to a concurrent resolution; park it for them.
                    self._unconsumed.append(dns_resp)
        # Peer asked us to stop sending: honor it by tearing the connection down.
        if self._should_disconnect:
            await self.close()
            self._should_disconnect = False
            self._terminated = True
        results = []
        for response in responses:
            if not response.is_ok:
                # rcode 2 (SERVFAIL) is most often DNSSEC-related; give pointers.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )
            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    # HTTPS RRs only feed the h3-upgrade hint, never addresses.
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue
                assert not isinstance(record[-1], dict)
                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )
                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )
        quic_results = []
        if remote_preemptive_quic_rr:
            # Mirror stream entries as datagram entries for the QUIC upgrade,
            # unless datagram entries already exist.
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)

    async def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """Drive the QUIC state machine until an event of *event_type* occurs.

        Flushes outgoing datagrams (unless ``receive_first``), then reads from
        the socket, feeds qh3, and collects events — either every event or only
        those matching ``event_type_collectable``. Raises ``SSLError``/
        ``socket.gaierror`` on connection termination or stream reset.
        """
        assert self._socket is not None
        while True:
            if receive_first is False:
                now = monotonic()
                while True:
                    datagrams = self._quic.datagrams_to_send(now)
                    if not datagrams:
                        break
                    for datagram in datagrams:
                        data, addr = datagram
                        await self._socket.sendall(data)
            events = []
            while True:
                if not self._quic._events:
                    data_in = await self._socket.recv(1500)
                    if not data_in:
                        break
                    now = monotonic()
                    # recv may return a list of GRO segments — feed each one.
                    if not isinstance(data_in, list):
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )
                    else:
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )
                    # Send any ACKs/retransmissions qh3 produced in response.
                    while True:
                        datagrams = self._quic.datagrams_to_send(now)
                        if not datagrams:
                            break
                        for datagram in datagrams:
                            data, addr = datagram
                            await self._socket.sendall(data)
                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        # Remember to disconnect once current queries complete.
                        self._should_disconnect = True
                        continue
                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)
                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events
            return events
class AdGuardResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC shortcut pointing at AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC shortcut pointing at dns.nextdns.io."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import socket
import typing
from .....util._async.ssl_ import ssl_wrap_socket
from .....util.ssl_ import resolve_cert_reqs
from ...protocols import ProtocolResolver
from ..dou import PlainResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # DoT runs over TCP; force a stream socket before the parent picks UDP.
        self._socket_type = socket.SOCK_STREAM
        # 853 is the default DoT port.
        super().__init__(server, port or 853, *patterns, **kwargs)
        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Establish (once) the TLS channel, then defer to ``PlainResolver``.

        The first caller connects and wraps the socket in TLS; concurrent
        callers fall through to ``super().getaddrinfo`` which waits on
        ``_connect_finalized`` before using the socket.
        """
        if self._socket is None and self._connect_attempt.is_set() is False:
            assert self.server is not None
            self._connect_attempt.set()
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )
            # Every TLS knob is forwarded from the resolver kwargs when present.
            self._socket = await ssl_wrap_socket(
                self._socket,
                server_hostname=self.server
                if "server_hostname" not in self._kwargs
                else self._kwargs["server_hostname"],
                keyfile=self._kwargs["key_file"]
                if "key_file" in self._kwargs
                else None,
                certfile=self._kwargs["cert_file"]
                if "cert_file" in self._kwargs
                else None,
                cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
                if "cert_reqs" in self._kwargs
                else None,
                ca_certs=self._kwargs["ca_certs"]
                if "ca_certs" in self._kwargs
                else None,
                ssl_version=self._kwargs["ssl_version"]
                if "ssl_version" in self._kwargs
                else None,
                ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
                ca_cert_dir=self._kwargs["ca_cert_dir"]
                if "ca_cert_dir" in self._kwargs
                else None,
                key_password=self._kwargs["key_password"]
                if "key_password" in self._kwargs
                else None,
                ca_cert_data=self._kwargs["ca_cert_data"]
                if "ca_cert_data" in self._kwargs
                else None,
                certdata=self._kwargs["cert_data"]
                if "cert_data" in self._kwargs
                else None,
                keydata=self._kwargs["key_data"]
                if "key_data" in self._kwargs
                else None,
            )
            self._connect_finalized.set()
        return await super().getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
class GoogleResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut pointing at dns.google."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut pointing at Cloudflare's 1.1.1.1."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut pointing at AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut pointing at dns.opendns.com."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut pointing at dns11.quad9.net."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; drop any user-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
import asyncio
import socket
import typing
from collections import deque
from ....ssa import AsyncSocket
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..protocols import AsyncBaseResolver
from ..system import SystemResolver
class PlainResolver(AsyncBaseResolver):
"""
Minimalist DNS resolver over UDP
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
EDNS is not supported, yet. But we plan to. Willing to contribute?
"""
protocol = ProtocolResolver.DOU
implementation = "socket"
    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Prepare (but do not open) a plain DNS transport toward *server*:*port*.

        The socket is created lazily on first ``getaddrinfo`` call. Recognized
        kwargs: ``source_address`` ("ip" or "ip:port" string) and ``timeout``
        (seconds); everything else is forwarded to the base resolver.
        """
        super().__init__(server, port, *patterns, **kwargs)
        self._socket: AsyncSocket | None = None
        # Subclasses (e.g. TLSResolver) may have forced SOCK_STREAM beforehand.
        if not hasattr(self, "_socket_type"):
            self._socket_type = socket.SOCK_DGRAM
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            if ":" in kwargs["source_address"]:
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
                self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
            else:
                # Bare IP: let the OS pick the local port.
                self._source_address = (kwargs["source_address"], 0)
        else:
            self._source_address = None
        if "timeout" in kwargs and isinstance(
            kwargs["timeout"],
            (
                float,
                int,
            ),
        ):
            self._timeout: float | int | None = kwargs["timeout"]
        else:
            self._timeout = None
        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False
        # Responses read on behalf of other concurrent resolutions (matched by DNS id).
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        # Queries sent but not yet answered.
        self._pending: deque[DomainNameServerQuery] = deque()
        # Serializes socket reads across concurrent resolutions.
        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        # First caller connects (_connect_attempt); others wait on _connect_finalized.
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._connect_finalized: asyncio.Event = asyncio.Event()
        self._terminated: bool = False
    async def close(self) -> None:  # type: ignore[override]
        """Close the underlying socket (if any) and mark the resolver terminated.

        Idempotent. ``self._lock`` is inherited from the base resolver —
        NOTE(review): it appears to be a synchronous lock held across an
        ``await``; confirm against AsyncBaseResolver.
        """
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    self._socket.close()
                    await self._socket.wait_for_close()
                self._terminated = True
def is_available(self) -> bool:
return not self._terminated
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a PlainResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
if self._socket is None and self._connect_attempt.is_set() is False:
self._connect_attempt.set()
assert self.server is not None
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 53),
timeout=self._timeout,
source_address=self._source_address,
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
socket_kind=self._socket_type,
)
self._connect_finalized.set()
else:
await self._connect_finalized.wait()
assert self._socket is not None
await self._socket.wait_for_readiness()
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
await self._socket.sendall(payload)
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
await self._read_semaphore.acquire()
#: There we want to verify if another thread got a response that belong to this thread.
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._pending.remove(query)
self._unconsumed.remove(dns_resp)
self._read_semaphore.release()
continue
try:
data_in_or_segments = await self._socket.recv(1500)
if isinstance(data_in_or_segments, list):
payloads = data_in_or_segments
elif data_in_or_segments:
payloads = [data_in_or_segments]
else:
payloads = []
if self._rfc1035_prefix_mandated is True and payloads:
payload = b"".join(payloads)
while rfc1035_should_read(payload):
extra = await self._socket.recv(1500)
if isinstance(extra, list):
payload += b"".join(extra)
else:
payload += extra
payloads = [payload]
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
self._read_semaphore.release()
if not payloads:
self._terminated = True
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
)
pending_raw_identifiers = [_.raw_id for _ in self._pending]
for payload in payloads:
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
if self._rfc1035_prefix_mandated is True:
fragments = rfc1035_unpack(payload)
else:
fragments = packet_fragment(payload, *pending_raw_identifiers)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS shortcut preconfigured for Cloudflare's public resolver (1.1.1.1)."""
    specifier = "cloudflare"
    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Any caller-supplied server is discarded: the vendor address is fixed.
        kwargs.pop("server", None)
        # An explicit port is still honored; otherwise the parent default applies.
        explicit_port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", explicit_port, *patterns, **kwargs)
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS shortcut preconfigured for Google's public resolver (8.8.8.8)."""
    specifier = "google"
    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Any caller-supplied server is discarded: the vendor address is fixed.
        kwargs.pop("server", None)
        # An explicit port is still honored; otherwise the parent default applies.
        explicit_port = kwargs.pop("port", None)
        super().__init__("8.8.8.8", explicit_port, *patterns, **kwargs)
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS shortcut preconfigured for Quad9's public resolver (9.9.9.9)."""
    specifier = "quad9"
    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Any caller-supplied server is discarded: the vendor address is fixed.
        kwargs.pop("server", None)
        # An explicit port is still honored; otherwise the parent default applies.
        explicit_port = kwargs.pop("port", None)
        super().__init__("9.9.9.9", explicit_port, *patterns, **kwargs)
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS shortcut preconfigured for AdGuard's public resolver (94.140.14.140)."""
    specifier = "adguard"
    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Any caller-supplied server is discarded: the vendor address is fixed.
        kwargs.pop("server", None)
        # An explicit port is still honored; otherwise the parent default applies.
        explicit_port = kwargs.pop("port", None)
        super().__init__("94.140.14.140", explicit_port, *patterns, **kwargs)

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ....util import parse_url
from ..factories import ResolverDescription
from ..protocols import ProtocolResolver
from .protocols import AsyncBaseResolver
class AsyncResolverFactory(metaclass=ABCMeta):
    """Locate and instantiate :class:`AsyncBaseResolver` implementations.

    Implementations live in submodules of ``<package>.contrib.resolver._async``
    named after the protocol (dashes become underscores), optionally suffixed
    with ``._<implementation>``.
    """
    @staticmethod
    def has(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
    ) -> bool:
        """Return True when :meth:`new` would be able to construct a resolver
        for ``protocol`` (and optional ``specifier``/``implementation``)."""
        package_name: str = __name__.split(".")[0]
        module_expr = f".{protocol.value.replace('-', '_')}"
        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"
        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError:
            return False
        # Apply the exact same member filter as new(): without the protocol
        # check, classes merely imported into the module (e.g. AsyncBaseResolver
        # itself) could make has() report something new() would refuse to build.
        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )
        if not implementations:
            return False
        return True
    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> AsyncBaseResolver:
        """Import the protocol submodule and instantiate a matching resolver.

        :raises NotImplementedError: when the module cannot be imported or no
            compatible :class:`AsyncBaseResolver` subclass is found in it.
        """
        package_name: str = __name__.split(".")[0]
        module_expr = f".{protocol.value.replace('-', '_')}"
        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"
        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e
        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )
        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )
        # Last match wins, mirroring the original behavior.
        implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]
        return implementation_target(**kwargs)
class AsyncResolverDescription(ResolverDescription):
    """Describe how a BaseResolver must be instantiated."""
    def new(self) -> AsyncBaseResolver:
        """Build the described resolver via the factory, forwarding server/port/patterns."""
        kwargs = {**self.kwargs}
        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns
        return AsyncResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )
    @staticmethod
    def from_url(url: str) -> AsyncResolverDescription:
        """Parse a DNS url (e.g. ``doh+google://u:p@host/?timeout=2``) into a description.

        Query parameters become constructor kwargs ("implementation" and "hosts"
        get special handling); credentials become an Authorization header.

        :raises ValueError: when the url lacks a scheme or passes more than one
            implementation.
        """
        parsed_url = parse_url(url)
        schema = parsed_url.scheme
        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")
        specifier = None
        implementation = None
        if "+" in schema:
            schema, specifier = tuple(schema.lower().split("+", 1))
        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}
        if parsed_url.path:
            kwargs["path"] = parsed_url.path
        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                # Split only once: the password itself may contain ':'.
                username, password = parsed_url.auth.split(":", 1)
                username = username.strip("'\"")
                password = password.strip("'\"")
                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)
            for parameter in parameters:
                if not parameters[parameter]:
                    continue
                parameter_insensible = parameter.lower()
                # Repeated parameter (?k=a&k=b), possibly with comma-joined values.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")
                    values = []
                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)
                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            # In-place extend keeps previously collected values;
                            # reassigning afterward (as before) would drop them.
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue
                    kwargs[parameter_insensible] = values
                    continue
                value: str = parameters[parameter][0].lower().strip(" ")
                if parameter == "implementation":
                    implementation = value
                    continue
                if "," in value:
                    list_of_values = value.split(",")
                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            list_of_values.append(kwargs[parameter_insensible])
                            # Write the merged list back: previously the scalar
                            # branch forgot this assignment and lost the values.
                            kwargs[parameter_insensible] = list_of_values
                        continue
                    kwargs[parameter_insensible] = list_of_values
                    continue
                # Best-effort scalar coercion: bool, int, then float.
                value_converted: bool | int | float | None = None
                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)
                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )
        host_patterns: list[str] = []
        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]
        return AsyncResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
import socket
import typing
from .....util.url import _IPV6_ADDRZ_RE
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class InMemoryResolver(AsyncBaseResolver):
    """A purely local resolver backed by an in-memory hostname -> address table.

    Entries come from constructor patterns shaped ``"hostname:ipaddr"`` or from
    :meth:`register`; it never performs any network query.
    """
    protocol = ProtocolResolver.MANUAL
    implementation = "dict"
    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # server/port make no sense for a resolver that never hits the network.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)
        # Cap on the number of distinct hostnames kept in memory.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
        if self._host_patterns:
            # Patterns double as seed records here ("hostname:ipaddr").
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            self._host_patterns = tuple([])
    def recycle(self) -> AsyncBaseResolver:
        # Stateless w.r.t. connections: safe to reuse as-is.
        return self
    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op
    def is_available(self) -> bool:
        return True
    def have_constraints(self) -> bool:
        # Only resolves hostnames explicitly registered with it.
        return True
    def support(self, hostname: str | bytes | None) -> bool | None:
        """True when ``hostname`` has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts
    def register(self, hostname: str, ipaddr: str) -> None:
        """Associate ``hostname`` with ``ipaddr``; duplicate pairs are ignored.

        Bracketed IPv6 literals (``[::1]``) are stored unbracketed.
        """
        # Normalize first so "[::1]" and "::1" compare equal.
        if _IPV6_ADDRZ_RE.match(ipaddr):
            entry: tuple[socket.AddressFamily, str] = (socket.AF_INET6, ipaddr[1:-1])
        elif is_ipv6(ipaddr):
            entry = (socket.AF_INET6, ipaddr)
        else:
            entry = (socket.AF_INET, ipaddr)
        if hostname not in self._hosts:
            self._hosts[hostname] = []
        else:
            for _family, known_addr in self._hosts[hostname]:
                # Exact comparison: the former substring test ("addr in ipaddr")
                # wrongly discarded distinct addresses sharing digits
                # (e.g. "1.2.3.4" blocked registering "11.2.3.45").
                if known_addr == entry[1]:
                    return
        self._hosts[hostname].append(entry)
        # Crude eviction: drop the oldest hostname once over capacity.
        if len(self._hosts) > self._maxsize:
            oldest = next(iter(self._hosts), None)
            if oldest:
                self._hosts.pop(oldest)
    def clear(self, hostname: str) -> None:
        """Forget every record attached to ``hostname`` (no-op if unknown)."""
        if hostname in self._hosts:
            del self._hosts[hostname]
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Mimic :func:`socket.getaddrinfo` using only the in-memory table.

        :raises socket.gaierror: when the hostname (or family-filtered subset)
            has no record.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # Literal addresses bypass the table entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        if host not in self._hosts:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        for entry in self._hosts[host]:
            addr_type, addr_target = entry
            if family != socket.AF_UNSPEC:
                if family != addr_type:
                    continue
            results.append(
                (
                    addr_type,
                    type,
                    6 if type == socket.SOCK_STREAM else 17,
                    "",
                    (addr_target, port)
                    if addr_type == socket.AF_INET
                    else (addr_target, port, 0, 0),
                )
            )
        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        return results

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import socket
import typing
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class NullResolver(AsyncBaseResolver):
    """A resolver that only accepts literal IP addresses and refuses hostnames.

    Handy for forbidding DNS resolution entirely while still allowing direct
    IPv4/IPv6 targets.
    """
    protocol = ProtocolResolver.NULL
    implementation = "dummy"
    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # server/port are meaningless: nothing is ever queried.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)
    def recycle(self) -> AsyncBaseResolver:
        # Stateless: the same instance can be handed out forever.
        return self
    async def close(self) -> None:  # type: ignore[override]
        pass  # nothing to release
    def is_available(self) -> bool:
        return True
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Mirror socket.getaddrinfo for literal addresses; raise gaierror otherwise."""
        target = "localhost" if host is None else host  # Defensive: stdlib cpy behavior
        resolved_port = 0 if port is None else port  # Defensive: stdlib cpy behavior
        if isinstance(resolved_port, str):
            resolved_port = int(resolved_port)  # Defensive: stdlib cpy behavior
        if resolved_port < 0:
            # Defensive: stdlib cpy behavior
            raise socket.gaierror("Servname not supported for ai_socktype")
        if isinstance(target, bytes):
            target = target.decode("ascii")  # Defensive: stdlib cpy behavior
        if is_ipv4(target):
            if family == socket.AF_INET6:
                # Defensive: stdlib cpy behavior
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET, type, 6, "", (target, resolved_port))]
        if is_ipv6(target):
            if family == socket.AF_INET:
                # Defensive: stdlib cpy behavior
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET6, type, 17, "", (target, resolved_port, 0, 0))]
        raise socket.gaierror(f"Tried to resolve '{target}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,375 @@
from __future__ import annotations
import asyncio
import ipaddress
import socket
import struct
import sys
import typing
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta, timezone
from ...._constant import UDP_LINUX_GRO
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ....exceptions import LocationParseError
from ....util.connection import _set_socket_options, allowed_gai_family
from ....util.timeout import _DEFAULT_TIMEOUT
from ...ssa import AsyncSocket
from ...ssa._timeout import timeout as timeout_
from ..protocols import BaseResolver
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
    """Async counterpart of BaseResolver: getaddrinfo/close are coroutines and
    create_connection returns an AsyncSocket."""
    def recycle(self) -> AsyncBaseResolver:
        # Narrow the sync base's return type for async callers.
        return super().recycle() # type: ignore[return-value]
    @abstractmethod
    async def close(self) -> None: # type: ignore[override]
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError
    @abstractmethod
    async def getaddrinfo( # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError
    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    async def create_connection( # type: ignore[override]
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> AsyncSocket:
        """Connect to *address* and return the socket object.
        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.

        *timing_hook*, when provided, receives a tuple of
        (resolution delta, establishment delta, completion timestamp).
        """
        host, port = address
        # Bracketed IPv6 literals ("[::1]") lose their brackets before resolution.
        if host.startswith("["):
            host = host.strip("[]")
        err = None
        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()
        if family != socket.AF_UNSPEC:
            default_socket_family = family
        # A concrete source address pins the family we may connect with.
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6
        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None
        dt_pre_resolve = datetime.now(tz=timezone.utc)
        if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
            # we can hang here in case of bad networking conditions
            # the DNS may never answer or the packets can be lost.
            # this isn't possible in sync mode. unfortunately.
            # found by user at https://github.com/jawah/niquests/issues/183
            # todo: find a way to limit getaddrinfo delays in sync mode.
            try:
                async with timeout_(timeout):
                    records = await self.getaddrinfo(
                        host,
                        port,
                        default_socket_family,
                        socket_kind,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
            except TimeoutError:
                raise socket.gaierror(
                    f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
                )
        else:
            records = await self.getaddrinfo(
                host,
                port,
                default_socket_family,
                socket_kind,
                quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
            )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
        dt_pre_established = datetime.now(tz=timezone.utc)
        # Try each resolved record in order; first successful connect wins.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = AsyncSocket(af, socktype, proto)
                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (OSError, AttributeError): # Defensive: very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ): # Defensive: we can't do anything better than this.
                            pass
                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass
                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError: # Defensive: oh, well(...) anyway!
                        pass
                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)
                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)
                try:
                    await sock.connect(sa)
                except asyncio.CancelledError:
                    # Do not leak the half-open socket when the task is cancelled.
                    sock.close()
                    raise
                # Break explicitly a reference cycle
                err = None
                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )
                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )
                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. invalid port) cannot be fixed by trying
                # the next record: bail out of the loop immediately.
                if isinstance(_, OverflowError):
                    break
        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            # Loop body never ran: the resolver produced zero records.
            raise OSError("getaddrinfo returns an empty list")
class AsyncManyResolver(AsyncBaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).

    Constrained resolvers (those restricted to specific host patterns) are
    consulted first; unconstrained ones serve as the general fallback pool,
    rotated round-robin based on how many lookups are in flight.
    """
    def __init__(self, *resolvers: AsyncBaseResolver) -> None:
        super().__init__(None, None)
        self._size = len(resolvers)
        # Split children by whether they only handle specific host patterns.
        self._unconstrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]
        # Number of concurrently running lookups; drives round-robin rotation.
        self._concurrent: int = 0
        self._terminated: bool = False
    def recycle(self) -> AsyncBaseResolver:
        # Build a fresh aggregate from recycled children.
        resolvers = []
        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())
        return AsyncManyResolver(*resolvers)
    async def close(self) -> None: # type: ignore[override]
        for resolver in self._unconstrained + self._constrained:
            await resolver.close()
        self._terminated = True
    def is_available(self) -> bool:
        return not self._terminated
    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[AsyncBaseResolver, None, None]:
        """Yield child resolvers starting at a rotating offset, recycling any
        that became unavailable along the way."""
        resolvers = self._unconstrained if not constrained else self._constrained
        if not resolvers:
            return
        with self._lock:
            self._concurrent += 1
        try:
            resolver_count = len(resolvers)
            # Rotate the starting point so concurrent lookups spread the load.
            start_idx = (self._concurrent - 1) % resolver_count
            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]
            # Wrap around to cover the resolvers before the starting offset.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1
    async def getaddrinfo( # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Try constrained resolvers that claim ``host`` first, then fall back
        to the unconstrained pool; raise gaierror when everything failed.
        DNSSEC-related failures are propagated immediately."""
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"
        tested_resolvers = []
        any_constrained_tried: bool = False
        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)
            if can_resolve is True:
                any_constrained_tried = True
                try:
                    results = await resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
                    if results:
                        return results
                except socket.gaierror as exc:
                    # DNSSEC failures are authoritative: do not mask them by
                    # falling through to another resolver.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)
        # A constrained resolver claimed the host but failed: stop here rather
        # than leaking the query to the general pool.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )
        for resolver in self.__resolvers():
            try:
                results = await resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )
                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue
        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
import socket
import typing
from ...protocols import ProtocolResolver
from ..protocols import AsyncBaseResolver
class SystemResolver(AsyncBaseResolver):
    """Delegates name resolution to the operating system via the running event loop."""
    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM
    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # server/port are irrelevant: the OS decides where queries go.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)
    def support(self, hostname: str | bytes | None) -> bool | None:
        # The OS resolver always knows about the local host.
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)
    def recycle(self) -> AsyncBaseResolver:
        # Stateless: reuse the same instance.
        return self
    async def close(self) -> None:  # type: ignore[override]
        pass  # nothing held open
    def is_available(self) -> bool:
        return True
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Forward straight to the event loop's getaddrinfo (the quic RR flag is ignored)."""
        loop = asyncio.get_running_loop()
        return await loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,641 @@
from __future__ import annotations
import socket
import typing
from base64 import b64encode
from ...._collections import HTTPHeaderDict
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
from ....connectionpool import HTTPSConnectionPool
from ....response import HTTPResponse
from ....util.url import parse_url
from ..protocols import (
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
class HTTPSResolver(BaseResolver):
"""
Advanced DNS over HTTPS resolver.
No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
also implemented at Cloudflare.
Support RFC 8484 without JSON. Disabled by default.
"""
implementation = "urllib3"
protocol = ProtocolResolver.DOH
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port or 443, *patterns, **kwargs)
self._path: str = "/resolve"
if "path" in kwargs:
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
self._path = kwargs["path"]
kwargs.pop("path")
self._rfc8484: bool = False
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
self._rfc8484 = True
kwargs.pop("rfc8484")
assert self._server is not None
if "source_address" in kwargs:
if isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
if bind_ip and bind_port.isdigit():
kwargs["source_address"] = (
bind_ip,
int(bind_port),
)
else:
raise ValueError("invalid source_address given in parameters")
else:
raise ValueError("invalid source_address given in parameters")
if "proxy" in kwargs:
kwargs["_proxy"] = parse_url(kwargs["proxy"])
kwargs.pop("proxy")
if "maxsize" not in kwargs:
kwargs["maxsize"] = 10
if "proxy_headers" in kwargs and "_proxy" in kwargs:
proxy_headers = HTTPHeaderDict()
if not isinstance(kwargs["proxy_headers"], list):
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
for item in kwargs["proxy_headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
proxy_headers.add(k, v)
kwargs["_proxy_headers"] = proxy_headers
if "headers" in kwargs:
headers = HTTPHeaderDict()
if not isinstance(kwargs["headers"], list):
kwargs["headers"] = [kwargs["headers"]]
for item in kwargs["headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
headers.add(k, v)
kwargs["headers"] = headers
if "disabled_svn" in kwargs:
if not isinstance(kwargs["disabled_svn"], list):
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
disabled_svn = set()
for svn in kwargs["disabled_svn"]:
svn = svn.lower()
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
kwargs["disabled_svn"] = disabled_svn
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
self._connection_callback: (
typing.Callable[[ConnectionInfo], None] | None
) = kwargs["on_post_connection"]
kwargs.pop("on_post_connection")
else:
self._connection_callback = None
self._pool = HTTPSConnectionPool(self._server, self._port, **kwargs)
def close(self) -> None:
    """Release the resolver by closing its underlying HTTPS connection pool."""
    pool = self._pool
    pool.close()
def is_available(self) -> bool:
    """Return True while the underlying connection pool is still usable."""
    inner_pool = self._pool.pool
    return inner_pool is not None
def getaddrinfo(
    self,
    host: bytes | str | None,
    port: str | int | None,
    family: socket.AddressFamily,
    type: socket.SocketKind,
    proto: int = 0,
    flags: int = 0,
    *,
    quic_upgrade_via_dns_rr: bool = False,
) -> list[
    tuple[
        socket.AddressFamily,
        socket.SocketKind,
        int,
        str,
        tuple[str, int] | tuple[str, int, int, int],
    ]
]:
    """Resolve *host* over DNS-over-HTTPS, mimicking ``socket.getaddrinfo``.

    A and/or AAAA questions (depending on *family*) are sent through the
    HTTPS pool either as JSON (``application/dns-json``) or, when
    ``self._rfc8484`` is set, as base64url-encoded RFC 8484
    ``application/dns-message`` payloads in the ``dns`` query parameter.
    When *quic_upgrade_via_dns_rr* is true an HTTPS (type 65) question is
    also issued; an ``alpn=h3`` answer makes the method prepend extra
    ``SOCK_DGRAM`` entries so the caller may attempt a QUIC upgrade.

    :raises socket.gaierror: on unusable arguments, a >=300 HTTP status, or
        a non-zero DNS status in the remote answer.
    """
    if host is None:
        raise socket.gaierror(  # Defensive: stdlib cpy behavior
            "Tried to resolve 'localhost' from a HTTPSResolver"
        )
    if port is None:
        port = 0  # Defensive: stdlib cpy behavior
    if isinstance(port, str):
        port = int(port)  # Defensive: stdlib cpy behavior
    if port < 0:
        raise socket.gaierror(  # Defensive: stdlib cpy behavior
            "Servname not supported for ai_socktype"
        )
    if isinstance(host, bytes):
        host = host.decode("ascii")  # Defensive: stdlib cpy behavior
    # IP literals short-circuit: no DNS traffic is generated for them.
    # NOTE(review): the proto slot is hard-coded (6 for IPv4, 17 for IPv6)
    # regardless of the requested socket *type* — confirm intent upstream.
    if is_ipv4(host):
        if family == socket.AF_INET6:
            raise socket.gaierror("Address family for hostname not supported")
        return [
            (
                socket.AF_INET,
                type,
                6,
                "",
                (
                    host,
                    port,
                ),
            )
        ]
    elif is_ipv6(host):
        if family == socket.AF_INET:
            raise socket.gaierror("Address family for hostname not supported")
        return [
            (
                socket.AF_INET6,
                type,
                17,
                "",
                (
                    host,
                    port,
                    0,
                    0,
                ),
            )
        ]
    validate_length_of(host)
    promises: list[HTTPResponse | ResponsePromise] = []
    remote_preemptive_quic_rr = False
    # A datagram request is already QUIC-oriented; probing the HTTPS RR
    # for an h3 upgrade would be redundant in that case.
    if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
        quic_upgrade_via_dns_rr = False
    # -- A (IPv4) question --
    if family in [socket.AF_UNSPEC, socket.AF_INET]:
        if not self._rfc8484:
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {"name": host, "type": "1"},
                    headers={"Accept": "application/dns-json"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
        else:
            # override_id=0 keeps the GET cache-friendly per RFC 8484.
            dns_query = DomainNameServerQuery(
                host, SupportedQueryType.A, override_id=0
            )
            dns_payload = bytes(dns_query)
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {
                        # base64url without padding, as mandated by RFC 8484.
                        "dns": b64encode(dns_payload).decode().replace("=", ""),
                    },
                    headers={"Accept": "application/dns-message"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
    # -- AAAA (IPv6) question --
    if family in [socket.AF_UNSPEC, socket.AF_INET6]:
        if not self._rfc8484:
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {"name": host, "type": "28"},
                    headers={"Accept": "application/dns-json"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
        else:
            dns_query = DomainNameServerQuery(
                host, SupportedQueryType.AAAA, override_id=0
            )
            dns_payload = bytes(dns_query)
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {
                        "dns": b64encode(dns_payload).decode().replace("=", ""),
                    },
                    headers={"Accept": "application/dns-message"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
    # -- HTTPS (type 65) question, used to detect h3 support --
    if quic_upgrade_via_dns_rr:
        if not self._rfc8484:
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {"name": host, "type": "65"},
                    headers={"Accept": "application/dns-json"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
        else:
            dns_query = DomainNameServerQuery(
                host, SupportedQueryType.HTTPS, override_id=0
            )
            dns_payload = bytes(dns_query)
            promises.append(
                self._pool.request_encode_url(
                    "GET",
                    self._path,
                    {
                        "dns": b64encode(dns_payload).decode().replace("=", ""),
                    },
                    headers={"Accept": "application/dns-message"},
                    on_post_connection=self._connection_callback,
                    multiplexed=True,
                )
            )
    # Collect the (possibly multiplexed) responses in emission order.
    responses: list[HTTPResponse] = []
    for promise in promises:
        if isinstance(promise, HTTPResponse):
            responses.append(promise)
            continue
        responses.append(self._pool.get_response(promise=promise))  # type: ignore[arg-type]
    results: list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ] = []
    for response in responses:
        if response.status >= 300:
            raise socket.gaierror(
                f"DNS over HTTPS was unsuccessful, server response status {response.status}."
            )
        if not self._rfc8484:
            # JSON (dns-json) answer path.
            payload = response.json()
            assert "Status" in payload and isinstance(payload["Status"], int)
            if payload["Status"] != 0:
                msg = (
                    payload["Comment"]
                    if "Comment" in payload
                    else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                )
                # Some providers return "Comment" as a list of strings.
                if isinstance(msg, list):
                    msg = ", ".join(msg)
                raise socket.gaierror(msg)
            assert "Question" in payload and isinstance(payload["Question"], list)
            if "Answer" not in payload:
                continue
            assert isinstance(payload["Answer"], list)
            for answer in payload["Answer"]:
                # Only A (1), AAAA (28) and HTTPS (65) records are relevant.
                if answer["type"] not in [1, 28, 65]:
                    continue
                assert "data" in answer
                assert isinstance(answer["data"], str)
                # DNS RR/HTTPS
                if answer["type"] == 65:
                    # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                    # or..
                    # "1 . alpn=h2,h3"
                    rr: str = answer["data"]
                    if rr.startswith("\\#"):  # it means, raw, bytes.
                        # Drop the "\\#" marker and the length field, then
                        # parse the remaining hex as raw HTTPS rdata.
                        rr = "".join(rr[2:].split(" ")[2:])
                        try:
                            raw_record = bytes.fromhex(rr)
                        except ValueError:
                            raw_record = b""
                        https_record = parse_https_rdata(raw_record)
                        if "h3" not in https_record["alpn"]:
                            continue
                        remote_preemptive_quic_rr = True
                    else:
                        # Presentation format: split "key=value" tokens.
                        rr_decode: dict[str, str] = dict(
                            tuple(_.lower().split("=", 1))  # type: ignore[misc]
                            for _ in rr.split(" ")
                            if "=" in _
                        )
                        if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                            continue
                        remote_preemptive_quic_rr = True
                        # ipv4hint/ipv6hint already carry usable addresses
                        # for the (QUIC) datagram path.
                        if "ipv4hint" in rr_decode and family in [
                            socket.AF_UNSPEC,
                            socket.AF_INET,
                        ]:
                            for ipv4 in rr_decode["ipv4hint"].split(","):
                                results.append(
                                    (
                                        socket.AF_INET,
                                        socket.SOCK_DGRAM,
                                        17,
                                        "",
                                        (
                                            ipv4,
                                            port,
                                        ),
                                    )
                                )
                        if "ipv6hint" in rr_decode and family in [
                            socket.AF_UNSPEC,
                            socket.AF_INET6,
                        ]:
                            for ipv6 in rr_decode["ipv6hint"].split(","):
                                results.append(
                                    (
                                        socket.AF_INET6,
                                        socket.SOCK_DGRAM,
                                        17,
                                        "",
                                        (
                                            ipv6,
                                            port,
                                            0,
                                            0,
                                        ),
                                    )
                                )
                    # HTTPS records never feed the A/AAAA result path below.
                    continue
                inet_type = (
                    socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        answer["data"],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        answer["data"],
                        port,
                        0,
                        0,
                    )
                )
                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )
        else:
            # Binary (RFC 8484 dns-message) answer path.
            dns_resp = DomainNameServerReturn(response.data)
            for record in dns_resp.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue
                assert not isinstance(record[-1], dict)
                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )
                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )
    quic_results: list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ] = []
    # When the remote advertised h3, duplicate every stream result as a
    # datagram candidate — unless the caller already asked for datagrams.
    if remote_preemptive_quic_rr:
        any_specified = False
        for result in results:
            if result[1] == socket.SOCK_STREAM:
                quic_results.append(
                    (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                )
            else:
                any_specified = True
                break
        if any_specified:
            quic_results = []
    return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Google serves RFC 8484 (dns-message) payloads on /dns-query only.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to Cloudflare (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Cloudflare answers DoH exclusively on the /dns-query endpoint.
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to AdGuard (unfiltered.adguard-dns.com)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # AdGuard expects RFC 8484 dns-message payloads on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # OpenDNS expects RFC 8484 dns-message payloads on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Quad9 expects RFC 8484 dns-message payloads on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(HTTPSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-HTTPS shortcut bound to NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations

try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    # DNS-over-QUIC is an optional extra (requires qh3); expose the names
    # anyway so callers can feature-detect by testing them against None.
    QUICResolver = AdGuardResolver = NextDNSResolver = None  # type: ignore

__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)

View File

@@ -0,0 +1,541 @@
from __future__ import annotations
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from ....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...ssa._gro import _sock_has_gro, _sock_has_gso, sync_recv_gro, sync_sendmsg_gso
from ..dou import PlainResolver
from ..protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
# DNS-over-QUIC relies on qh3, whose crypto primitives are not available on
# FIPS-restricted builds; fail at import time so the optional module is
# simply absent (the package __init__ catches this ImportError).
if IS_FIPS:
    raise ImportError(
        "DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
    )
class QUICResolver(PlainResolver):
    """DNS resolver over QUIC (DoQ) backed by the qh3 QUIC implementation.

    The UDP socket is created by :class:`PlainResolver`; this class drives a
    QUIC connection over it (handshake in ``__init__``), sends each DNS
    question on its own stream, and mandates the RFC 1035 two-byte length
    prefix as DoQ requires.
    """

    protocol = ProtocolResolver.DOQ
    implementation = "qh3"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        # PlainResolver creates the underlying UDP socket (default port 853).
        super().__init__(server, port or 853, *patterns, **kwargs)
        # qh3 load_default_certs seems off. need to investigate.
        # Workaround: export the system trust store to PEM and hand it to
        # qh3 via ca_cert_data when the caller provided no CA material.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            try:
                ctx.load_default_certs()
                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    # Nothing exported: remove the key so the check below fires.
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                del kwargs["ca_cert_data"]
        # With no CA at all, certificate verification cannot succeed; refuse
        # unless the caller explicitly disabled cert checks.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )
        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )
        # Optional client certificate, either from files or in-memory data.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        self._quic = QuicConnection(configuration=configuration)
        # Kernel-assisted UDP segmentation/coalescing availability.
        self._dgram_gro_enabled: bool = _sock_has_gro(self._socket)
        self._dgram_gso_enabled: bool = _sock_has_gso(self._socket)
        # Drive the QUIC handshake synchronously before accepting queries.
        self._quic.connect((self._server, self._port), monotonic())
        self.__exchange_until(HandshakeCompleted, receive_first=False)
        self._terminated: bool = False
        self._should_disconnect: bool = False
        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True
        # Responses that matched no query of the receiving thread, and
        # queries still awaiting a response (shared across threads).
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

    def close(self) -> None:
        """Gracefully terminate the QUIC connection and close the socket."""
        if not self._terminated:
            with self._lock:
                self._quic.close()
                # Flush the CONNECTION_CLOSE (and any pending) datagrams.
                while True:
                    datagrams = self._quic.datagrams_to_send(monotonic())
                    if not datagrams:
                        break
                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])
                self._socket.close()
                self._terminated = True

    def is_available(self) -> bool:
        """Return True while the QUIC connection has not been closed."""
        # Give qh3 a chance to process idle/loss timers first.
        self._quic.handle_timer(monotonic())
        # NOTE(review): peeks at a qh3 private attribute to detect closure.
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DoQ, mimicking ``socket.getaddrinfo``.

        Each question (A/AAAA and optionally HTTPS for h3 discovery) is sent
        on its own QUIC stream; responses are matched back by DNS message id.

        :raises socket.gaierror: on unusable arguments, transport failure or
            a DNS error code in an answer.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # IP literals short-circuit: no DNS traffic is generated for them.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        validate_length_of(host)
        remote_preemptive_quic_rr = False
        # A datagram request is already QUIC-oriented; skip h3 discovery.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False
        # Build the list of question types to emit.
        tbq = []
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)
        queries = DomainNameServerQuery.bulk(host, *tbq)
        open_streams = []
        with self._lock:
            # One query per stream, fin=True as DoQ expects.
            for q in queries:
                payload = bytes(q)
                self._pending.append(q)
                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)
                stream_id = self._quic.get_next_available_stream_id()
                self._quic.send_stream_data(stream_id, payload, True)
                open_streams.append(stream_id)
            datagrams = self._quic.datagrams_to_send(monotonic())
            if self._dgram_gso_enabled and len(datagrams) > 1:
                sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
            else:
                for dg in datagrams:
                    self._socket.sendall(dg[0])
        responses: list[DomainNameServerReturn] = []
        while len(responses) < len(tbq):
            with self._lock:
                # First, check whether another thread already received an
                # answer that belongs to one of our queries.
                if self._unconsumed:
                    dns_resp = None
                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break
                    if dns_resp:
                        self._unconsumed.remove(dns_resp)
                        self._pending.remove(query)
                        continue
                try:
                    events: list[StreamDataReceived] = self.__exchange_until(  # type: ignore[assignment]
                        StreamDataReceived,
                        receive_first=True,
                        event_type_collectable=(StreamDataReceived,),
                        respect_end_stream_signal=False,
                    )
                    payload = b"".join([e.data for e in events])
                    # Keep reading until the 2-byte length prefix is honored.
                    while rfc1035_should_read(payload):
                        events.extend(
                            self.__exchange_until(  # type: ignore[arg-type]
                                StreamDataReceived,
                                receive_first=True,
                                event_type_collectable=(StreamDataReceived,),
                                respect_end_stream_signal=False,
                            )
                        )
                        payload = b"".join([e.data for e in events])
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e
                if not payload:
                    continue
                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                fragments = rfc1035_unpack(payload)
                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)
                    if any(dns_resp.id == _.id for _ in queries):
                        responses.append(dns_resp)
                        query_tbr: DomainNameServerQuery | None = None
                        for query_tbr in self._pending:
                            if query_tbr.id == dns_resp.id:
                                break
                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        # Belongs to another thread's query; park it.
                        self._unconsumed.append(dns_resp)
        # Honor a server-issued STOP_SENDING seen during the exchange.
        if self._should_disconnect:
            with self._lock:
                self.close()
                self._should_disconnect = False
                self._terminated = True
        results = []
        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )
            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue
                assert not isinstance(record[-1], dict)
                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )
                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )
        quic_results = []
        # h3 advertised: duplicate stream results as datagram candidates,
        # unless the caller explicitly requested datagrams already.
        if remote_preemptive_quic_rr:
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)

    def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """Pump the QUIC engine until *event_type* is observed.

        Sends pending datagrams (unless *receive_first*), then alternates
        receive/send until the target event arrives, collecting events of
        *event_type_collectable* (or all events when it is None).
        """
        while True:
            if receive_first is False:
                now = monotonic()
                # Flush whatever qh3 has queued before listening.
                while True:
                    datagrams = self._quic.datagrams_to_send(now)
                    if not datagrams:
                        break
                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])
            events = []
            while True:
                # Only hit the wire when qh3 has no buffered events left.
                if not self._quic._events:
                    if self._dgram_gro_enabled:
                        data_in = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in = self._socket.recv(1500)
                    if not data_in:
                        break
                    now = monotonic()
                    # GRO may hand back a list of coalesced segments.
                    if isinstance(data_in, list):
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )
                    else:
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )
                    # Send any ACKs/retransmissions triggered by the input.
                    while True:
                        now = monotonic()
                        datagrams = self._quic.datagrams_to_send(now)
                        if not datagrams:
                            break
                        if self._dgram_gso_enabled and len(datagrams) > 1:
                            sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                        else:
                            for datagram in datagrams:
                                self._socket.sendall(datagram[0])
                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        # 298 is surfaced by qh3 on chain verification failure.
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        # Remember to tear down once the exchange completes.
                        self._should_disconnect = True
                        continue
                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)
                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events
            return events
class AdGuardResolver(QUICResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-QUIC shortcut bound to AdGuard (unfiltered.adguard-dns.com)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(QUICResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-QUIC shortcut bound to NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import socket
import typing
from ....util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
from ..dou import PlainResolver
from ..protocols import ProtocolResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858

    A TCP connection is established first, then wrapped with TLS; queries
    use the dns-message wire format with the two-byte length prefix that
    stream transports mandate.
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # Optional socket timeout; non-numeric values are ignored.
        if "timeout" in kwargs and isinstance(kwargs["timeout"], (int, float)):
            timeout = kwargs["timeout"]
        else:
            timeout = None
        # Optional "ip:port" local bind; "0.0.0.0:0" means OS default.
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            bind_ip, bind_port = kwargs["source_address"].split(":", 1)
        else:
            bind_ip, bind_port = "0.0.0.0", "0"
        # Establish the plain TCP connection (DoT default port 853);
        # PlainResolver.__init__ below keeps this socket since it only
        # creates one when the attribute is absent.
        self._socket = SystemResolver().create_connection(
            (server, port or 853),
            timeout=timeout,
            source_address=(bind_ip, int(bind_port))
            if bind_ip != "0.0.0.0" or bind_port != "0"
            else None,
            socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
            socket_kind=socket.SOCK_STREAM,
        )
        super().__init__(server, port, *patterns, **kwargs)
        # Upgrade the TCP socket to TLS; every TLS kwarg is optional and
        # forwarded as-is when present.
        self._socket = ssl_wrap_socket(
            self._socket,
            server_hostname=server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            keyfile=kwargs["key_file"] if "key_file" in kwargs else None,
            certfile=kwargs["cert_file"] if "cert_file" in kwargs else None,
            cert_reqs=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else None,
            ca_certs=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            ssl_version=kwargs["ssl_version"] if "ssl_version" in kwargs else None,
            ciphers=kwargs["ciphers"] if "ciphers" in kwargs else None,
            ca_cert_dir=kwargs["ca_cert_dir"] if "ca_cert_dir" in kwargs else None,
            key_password=kwargs["key_password"] if "key_password" in kwargs else None,
            ca_cert_data=kwargs["ca_cert_data"] if "ca_cert_data" in kwargs else None,
            certdata=kwargs["cert_data"] if "cert_data" in kwargs else None,
            keydata=kwargs["key_data"] if "key_data" in kwargs else None,
        )
        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True
class GoogleResolver(TLSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-TLS shortcut bound to Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(TLSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-TLS shortcut bound to Cloudflare public DNS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(TLSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-TLS shortcut bound to AdGuard (unfiltered.adguard-dns.com)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(TLSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-TLS shortcut bound to OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(TLSResolver):  # Defensive: vendor shortcut, excluded from coverage
    """DNS-over-TLS shortcut bound to Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed by the vendor; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,415 @@
from __future__ import annotations
import socket
import typing
from collections import deque
from ...ssa._gro import _sock_has_gro, sync_recv_gro
from ..protocols import (
COMMON_RCODE_LABEL,
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..system import SystemResolver
from ..utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
class PlainResolver(BaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        super().__init__(server, port, *patterns, **kwargs)
        # Subclasses (e.g. TLSResolver) may have created the socket already;
        # only build the default UDP socket when none exists yet.
        if not hasattr(self, "_socket"):
            # Optional socket timeout; non-numeric values are ignored.
            if "timeout" in kwargs and isinstance(
                kwargs["timeout"],
                (
                    float,
                    int,
                ),
            ):
                timeout = kwargs["timeout"]
            else:
                timeout = None
            # Optional "ip:port" local bind; "0.0.0.0:0" means OS default.
            if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
            else:
                bind_ip, bind_port = "0.0.0.0", "0"
            self._socket = SystemResolver().create_connection(
                (server, port or 53),
                timeout=timeout,
                source_address=(bind_ip, int(bind_port))
                if bind_ip != "0.0.0.0" or bind_port != "0"
                else None,
                socket_options=None,
                socket_kind=socket.SOCK_DGRAM,
            )
        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False
        # Kernel-assisted receive coalescing availability (UDP GRO).
        self._gro_enabled: bool = _sock_has_gro(self._socket)
        # Responses that matched no query of the receiving thread, and
        # queries still awaiting a response (shared across threads).
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()
        self._terminated: bool = False

    def close(self) -> None:
        """Shut down and close the underlying socket (idempotent)."""
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    try:
                        # SHUT_RD (== 0) stops further receives. shutdown()
                        # can raise OSError (e.g. ENOTCONN depending on
                        # platform/socket state); never let that skip the
                        # close() below or we would leak the descriptor and
                        # leave the resolver looking available.
                        self._socket.shutdown(socket.SHUT_RD)
                    except OSError:
                        pass
                    self._socket.close()
                self._terminated = True

    def is_available(self) -> bool:
        """Return True until :meth:`close` has been called."""
        return not self._terminated

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over plain DNS, mimicking ``socket.getaddrinfo``.

        Emits A and/or AAAA questions (per *family*) and, when
        *quic_upgrade_via_dns_rr* is true, an HTTPS (type 65) question whose
        ``alpn=h3`` answer yields extra ``SOCK_DGRAM`` entries for a QUIC
        upgrade attempt.

        :raises socket.gaierror: on unusable arguments, transport failure or
            a DNS error code in an answer.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # IP literals short-circuit: no DNS traffic is generated for them.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        validate_length_of(host)
        remote_preemptive_quic_rr = False
        # A datagram request is already QUIC-oriented; skip h3 discovery.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False
        # Build the list of question types to emit.
        tbq = []
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)
        queries = DomainNameServerQuery.bulk(host, *tbq)
        with self._lock:
            for q in queries:
                payload = bytes(q)
                self._pending.append(q)
                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)
                self._socket.sendall(payload)
        responses: list[DomainNameServerReturn] = []
        while len(responses) < len(tbq):
            with self._lock:
                #: There we want to verify if another thread got a response that belong to this thread.
                if self._unconsumed:
                    dns_resp = None
                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break
                    if dns_resp:
                        self._pending.remove(query)
                        self._unconsumed.remove(dns_resp)
                        continue
                try:
                    if self._gro_enabled:
                        data_in_or_segments = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in_or_segments = self._socket.recv(1500)
                    # Normalize to a list of raw payloads (GRO returns one).
                    if isinstance(data_in_or_segments, list):
                        payloads = data_in_or_segments
                    elif data_in_or_segments:
                        payloads = [data_in_or_segments]
                    else:
                        payloads = []
                    # Stream-style transports: honor the 2-byte size prefix
                    # and keep reading until the message is complete.
                    if self._rfc1035_prefix_mandated is True and payloads:
                        payload = b"".join(payloads)
                        while rfc1035_should_read(payload):
                            extra = self._socket.recv(1500)
                            if isinstance(extra, list):
                                payload += b"".join(extra)
                            else:
                                payload += extra
                        payloads = [payload]
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e
                if not payloads:
                    self._terminated = True
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    )
                pending_raw_identifiers = [_.raw_id for _ in self._pending]
                for payload in payloads:
                    #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                    if self._rfc1035_prefix_mandated is True:
                        fragments = rfc1035_unpack(payload)
                    else:
                        fragments = packet_fragment(payload, *pending_raw_identifiers)
                    for fragment in fragments:
                        dns_resp = DomainNameServerReturn(fragment)
                        if any(dns_resp.id == _.id for _ in queries):
                            responses.append(dns_resp)
                            query_tbr: DomainNameServerQuery | None = None
                            for query_tbr in self._pending:
                                if query_tbr.id == dns_resp.id:
                                    break
                            if query_tbr:
                                self._pending.remove(query_tbr)
                        else:
                            # Belongs to another thread's query; park it.
                            self._unconsumed.append(dns_resp)
        results = []
        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )
            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue
                assert not isinstance(record[-1], dict)
                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )
                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )
        quic_results = []
        # h3 advertised: duplicate stream results as datagram candidates,
        # unless the caller explicitly requested datagrams already.
        if remote_preemptive_quic_rr:
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS resolver hard-wired to Cloudflare's public server (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The target server is fixed; drop any caller-supplied server,
        # but honor an explicit port if one was given.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS resolver hard-wired to Google's public server (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The target server is fixed; drop any caller-supplied server,
        # but honor an explicit port if one was given.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS resolver hard-wired to Quad9's public server (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The target server is fixed; drop any caller-supplied server,
        # but honor an explicit port if one was given.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS resolver hard-wired to AdGuard's public server (94.140.14.140)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The target server is fixed; drop any caller-supplied server,
        # but honor an explicit port if one was given.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,230 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ...util import parse_url
from .protocols import BaseResolver, ProtocolResolver
class ResolverFactory(metaclass=ABCMeta):
    """Locate and instantiate a concrete :class:`BaseResolver` implementation."""

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> BaseResolver:
        """Import the submodule matching *protocol* and build a resolver from it.

        :param protocol: Transport to use (doh, dot, dou, ...).
        :param specifier: Optional vendor/shortcut tag a subclass declares
            via its class-level ``specifier`` attribute (None = default impl).
        :param implementation: Optional backend submodule name (e.g. a
            specific third-party engine); mapped to a leading-underscore module.
        :param kwargs: Forwarded verbatim to the resolver constructor.
        :raises NotImplementedError: when no matching module/class can be found.
        """
        # Top-level package name this module lives in (e.g. "urllib3").
        package_name: str = __name__.split(".")[0]
        # The enum value doubles as the submodule name ("in-memory" -> "in_memory").
        module_expr = f".{protocol.value.replace('-', '_')}"
        if implementation:
            # Backend-specific modules are "private": prefixed with "_".
            module_expr += f"._{implementation.replace('-', '_').lower()}"
        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e
        # Collect every BaseResolver subclass in the module whose class-level
        # ``specifier`` matches the requested one (both None means "default").
        implementations: list[tuple[str, type[BaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, BaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            ),
        )
        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )
        # getmembers() returns members sorted by name; pop() takes the last one.
        # NOTE(review): with several matches the pick is alphabetical — confirm intended.
        implementation_target: type[BaseResolver] = implementations.pop()[1]
        return implementation_target(**kwargs)
class ResolverDescription:
    """Describe how a BaseResolver must be instantiated."""

    def __init__(
        self,
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        server: str | None = None,
        port: int | None = None,
        *host_patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """
        :param protocol: DNS transport to use.
        :param specifier: Optional vendor/shortcut tag (e.g. "google").
        :param implementation: Optional backend submodule name.
        :param server: Upstream DNS server host, when applicable.
        :param port: Upstream DNS server port, when applicable.
        :param host_patterns: Hostnames this resolver is constrained to.
        :param kwargs: Extra options forwarded to the resolver constructor.
        """
        self.protocol = protocol
        self.specifier = specifier
        self.implementation = implementation
        self.server = server
        self.port = port
        self.host_patterns = host_patterns
        self.kwargs = kwargs

    def __setitem__(self, key: str, value: typing.Any) -> None:
        self.kwargs[key] = value

    def __contains__(self, item: str) -> bool:
        return item in self.kwargs

    def new(self) -> BaseResolver:
        """Build the resolver instance this description stands for."""
        kwargs = {**self.kwargs}
        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns
        return ResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> ResolverDescription:
        """Parse a DNS URL (e.g. ``doh+google://``) into a description.

        Query-string parameters become constructor kwargs: repeated and
        comma-separated values are merged into lists, and scalar values are
        converted to bool/int/float when they look like one. The special
        parameters ``implementation`` and ``hosts`` are extracted.

        :raises ValueError: when the URL lacks a scheme or carries more than
            one ``implementation`` parameter.
        """
        parsed_url = parse_url(url)
        schema = parsed_url.scheme
        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")
        specifier = None
        implementation = None
        if "+" in schema:
            # e.g. "doh+cloudflare" -> protocol "doh", specifier "cloudflare"
            schema, specifier = tuple(schema.lower().split("+", 1))
        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}
        if parsed_url.path:
            kwargs["path"] = parsed_url.path
        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                # maxsplit=1 so a password containing ":" survives intact.
                username, password = parsed_url.auth.split(":", 1)
                username = username.strip("'\"")
                password = password.strip("'\"")
                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                # A lone userinfo token is treated as a bearer token.
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)
            for parameter in parameters:
                if not parameters[parameter]:
                    continue
                parameter_insensible = parameter.lower()
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    # Parameter given several times: gather every value.
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")
                    values = []
                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)
                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue
                    kwargs[parameter_insensible] = values
                    continue
                value: str = parameters[parameter][0].lower().strip(" ")
                if parameter == "implementation":
                    implementation = value
                    continue
                if "," in value:
                    # Single occurrence holding a comma-separated list.
                    list_of_values = value.split(",")
                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            list_of_values.append(kwargs[parameter_insensible])
                            # bugfix: previously the merged list was discarded,
                            # silently dropping every value for this parameter.
                            kwargs[parameter_insensible] = list_of_values
                        continue
                    kwargs[parameter_insensible] = list_of_values
                    continue
                # Scalar: opportunistically convert bool/int/float literals.
                value_converted: bool | int | float | None = None
                if value in ["false", "true"]:
                    value_converted = value == "true"
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)
                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )
        host_patterns: list[str] = []
        if "hosts" in kwargs:
            # "hosts" is not a constructor kwarg: it becomes the pattern list.
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]
        return ResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
import socket
import typing
from ....util.url import _IPV6_ADDRZ_RE
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class InMemoryResolver(BaseResolver):
    """Resolver backed by an in-process host table (think /etc/hosts)."""

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        """
        :param patterns: Seed records, each formatted as "hostname:address".
        :param kwargs: Supports "maxsize" (max number of hostnames kept).
        """
        # No upstream server involved: silently drop server/port hints.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)
        # Cap on the number of distinct hostnames held in the table.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        # hostname -> list of (address family, textual address) entries.
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
        # Patterns are repurposed here as "hostname:address" seed records,
        # then cleared so support() relies on the table instead.
        if self._host_patterns:
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
        self._host_patterns = tuple([])
        # probably about our happy eyeballs impl (sync only)
        if len(self._hosts) == 1 and len(self._hosts[list(self._hosts.keys())[0]]) == 1:
            self._unsafe_expose = True

    def recycle(self) -> BaseResolver:
        # Stateless w.r.t. transport: the same instance can be reused forever.
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Always constrained: only hostnames present in the table resolve.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """True when *hostname* has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add *ipaddr* as an answer for *hostname*; duplicates are ignored."""
        with self._lock:
            if hostname not in self._hosts:
                self._hosts[hostname] = []
            else:
                for e in self._hosts[hostname]:
                    t, addr = e
                    # Substring test so a stored "::1" also matches "[::1]".
                    # NOTE(review): this also matches partial overlaps
                    # (e.g. "1.2.3.4" in "11.2.3.45") — confirm intended.
                    if addr in ipaddr:
                        return
            if _IPV6_ADDRZ_RE.match(ipaddr):
                # Bracketed IPv6 literal: store it without the brackets.
                self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
            elif is_ipv6(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr))
            else:
                self._hosts[hostname].append((socket.AF_INET, ipaddr))
            if len(self._hosts) > self._maxsize:
                # Evict the oldest hostname (dict insertion order) to bound memory.
                k = None
                for k in self._hosts.keys():
                    break
                if k:
                    self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every record registered for *hostname* (no-op if absent)."""
        with self._lock:
            if hostname in self._hosts:
                del self._hosts[hostname]

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* from the in-memory table, mimicking stdlib defaults."""
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        # IP literals short-circuit the table entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,  # NOTE(review): proto hardcoded to TCP here regardless of `type`
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,  # NOTE(review): proto hardcoded to UDP here regardless of `type`
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]
        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []
        with self._lock:
            if host not in self._hosts:
                raise socket.gaierror(f"no records found for hostname {host} in-memory")
            for entry in self._hosts[host]:
                addr_type, addr_target = entry
                # Honor an explicit family filter; AF_UNSPEC returns both.
                if family != socket.AF_UNSPEC:
                    if family != addr_type:
                        continue
                results.append(
                    (
                        addr_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        (addr_target, port)
                        if addr_type == socket.AF_INET
                        else (addr_target, port, 0, 0),
                    )
                )
        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")
        return results

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class NullResolver(BaseResolver):
    """Resolver that purposely refuses everything but literal IP addresses."""

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # No upstream server involved: discard server/port hints if present.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> BaseResolver:
        # Nothing to rebuild: the instance never becomes unusable.
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Answer only for IP literals; any real hostname raises gaierror."""
        # Mirror the stdlib defaults and sanity checks, in the same order.
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(
                "Servname not supported for ai_socktype"
            )  # Defensive: stdlib cpy behavior
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(
                    "Address family for hostname not supported"
                )  # Defensive: stdlib cpy behavior
            ipv4_entry = (socket.AF_INET, type, 6, "", (host, port))
            return [ipv4_entry]
        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(
                    "Address family for hostname not supported"
                )  # Defensive: stdlib cpy behavior
            ipv6_entry = (socket.AF_INET6, type, 17, "", (host, port, 0, 0))
            return [ipv6_entry]
        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")


__all__ = ("NullResolver",)

View File

@@ -0,0 +1,655 @@
from __future__ import annotations
import ipaddress
import socket
import struct
import sys
import threading
import typing
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from enum import Enum
from random import randint
from ..._constant import UDP_LINUX_GRO
from ..._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ...exceptions import LocationParseError
from ...util.connection import _set_socket_options, allowed_gai_family
from ...util.ssl_match_hostname import CertificateError, match_hostname
from ...util.timeout import _DEFAULT_TIMEOUT
from .utils import inet4_ntoa, inet6_ntoa, parse_https_rdata
if typing.TYPE_CHECKING:
from .utils import HttpsRecord
class ProtocolResolver(str, Enum):
    """
    At urllib3.future we aim to propose a wide range of DNS-protocols.
    The most used techniques are available.

    The enum value also serves as the name of the submodule hosting the
    implementation (see the resolver factory).
    """

    #: Ask the OS native DNS layer
    SYSTEM = "system"
    #: DNS over HTTPS
    DOH = "doh"
    #: DNS over QUIC
    DOQ = "doq"
    #: DNS over TLS
    DOT = "dot"
    #: DNS over UDP (insecure)
    DOU = "dou"
    #: Manual (e.g. hosts)
    MANUAL = "in-memory"
    #: Void (e.g. purposely disable resolution)
    NULL = "null"
    #: Custom (e.g. your own implementation, use this when it does not suit any of the protocols specified)
    CUSTOM = "custom"
class BaseResolver(metaclass=ABCMeta):
    """Common contract and plumbing shared by every DNS resolver implementation."""

    #: DNS transport this resolver speaks (see ProtocolResolver).
    protocol: typing.ClassVar[ProtocolResolver]
    #: Optional vendor/shortcut tag (e.g. "google"); None means default impl.
    specifier: typing.ClassVar[str | None] = None
    #: Short name of the backing implementation.
    implementation: typing.ClassVar[str]

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """
        :param server: Upstream DNS server (None when not applicable).
        :param port: Upstream DNS server port (None -> implementation default).
        :param patterns: Hostname patterns this resolver is constrained to.
        :param kwargs: Implementation-specific options; retained for recycle().
        """
        self._server = server
        self._port = port
        self._host_patterns: tuple[str, ...] = patterns
        self._lock = threading.Lock()
        # Kept verbatim so recycle() can rebuild an equivalent instance.
        self._kwargs = kwargs
        if not self._host_patterns and "patterns" in kwargs:
            self._host_patterns = kwargs["patterns"]
        # allow to temporarily expose a sock that is "being" created
        # this helps with our Happy Eyeballs implementation in sync.
        self._unsafe_expose: bool = False
        self._sock_cursor: socket.socket | None = None

    def recycle(self) -> BaseResolver:
        """Build a fresh, usable clone of this (already closed) resolver.

        :raises RuntimeError: when called on a resolver that is still open.
        """
        if self.is_available():
            raise RuntimeError("Attempting to recycle a Resolver that was not closed")
        # Inspect the concrete __init__ signature to know how to re-invoke it.
        args = list(self.__class__.__init__.__code__.co_varnames)
        args.remove("self")
        kwargs_cpy = deepcopy(self._kwargs)
        if self._server:
            kwargs_cpy["server"] = self._server
        if self._port:
            kwargs_cpy["port"] = self._port
        if "patterns" in args and "kwargs" in args:
            return self.__class__(*self._host_patterns, **kwargs_cpy)  # type: ignore[arg-type]
        elif "kwargs" in args:
            return self.__class__(**kwargs_cpy)
        return self.__class__()  # type: ignore[call-arg]

    @property
    def server(self) -> str | None:
        """Upstream DNS server this resolver targets, if any."""
        return self._server

    @property
    def port(self) -> int | None:
        """Upstream DNS server port, if any."""
        return self._port

    def have_constraints(self) -> bool:
        """True when this resolver only answers for a restricted set of hosts."""
        return bool(self._host_patterns)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """
        Determine if given hostname is especially resolvable by given resolver.
        If this resolver does not have any constrained list of host, it returns None. Meaning
        it support any hostname for resolution.
        """
        if not self._host_patterns:
            return None
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        try:
            # Reuse the TLS certificate hostname matcher so patterns
            # benefit from its wildcard semantics.
            match_hostname(
                {"subjectAltName": (tuple(("DNS", e) for e in self._host_patterns))},
                hostname,
            )
        except CertificateError:
            return False
        return True

    @abstractmethod
    def close(self) -> None:
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """Determine if Resolver can receive inquiries."""
        raise NotImplementedError

    @abstractmethod
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    def create_connection(
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> socket.socket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.

        ``timing_hook``, when given, receives a 3-tuple:
        (resolution duration, connection duration, completion timestamp).
        """
        host, port = address
        if host.startswith("["):
            # Bracketed IPv6 literal: drop the brackets before resolving.
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        if source_address is not None:
            # Binding to an explicit source pins the address family
            # for the whole connection attempt.
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        dt_pre_resolve = datetime.now(tz=timezone.utc)

        records = self.getaddrinfo(
            host,
            port,
            default_socket_family,
            socket_kind,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )

        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)

        # Try each resolved record in order until one connects.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address is not None:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (
                        OSError,
                        AttributeError,
                    ):  # Defensive: Windows or very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        # Hard close (no TIME_WAIT lingering) so the source
                        # port frees up immediately on close().
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                    sock.bind(source_address)

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if self._unsafe_expose:
                    # Happy Eyeballs (sync): expose the in-flight socket so a
                    # concurrent attempt can be observed/cancelled.
                    self._sock_cursor = sock
                sock.connect(sa)
                if self._unsafe_expose:
                    self._sock_cursor = None
                # Break explicitly a reference cycle
                err = None
                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )
                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )
                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                if isinstance(_, OverflowError):
                    # e.g. port out of range — retrying other records won't help.
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
class ManyResolver(BaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: BaseResolver) -> None:
        """:param resolvers: Child resolvers, consulted in the given order."""
        super().__init__(None, None)
        self._size = len(resolvers)
        # Children constrained to specific hosts are always consulted first.
        self._unconstrained: list[BaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[BaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]
        # Number of in-flight inquiries; drives the round-robin start index.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> BaseResolver:
        """Return a fresh aggregate built from recycled children."""
        resolvers = []
        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())
        return ManyResolver(*resolvers)

    def close(self) -> None:
        """Close every child resolver and mark this aggregate unusable."""
        for resolver in self._unconstrained + self._constrained:
            resolver.close()
        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[BaseResolver, None, None]:
        """Yield children round-robin, recycling any that became unavailable.

        The starting index rotates with the concurrent inquiry count so
        parallel callers spread load instead of all hitting the first child.
        """
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            start_idx = (self._concurrent - 1) % resolver_count

            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            if start_idx > 0:
                # Wrap around to also serve the children before start_idx.
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Try constrained children that support *host* first, then generic ones.

        DNSSEC-related failures are re-raised immediately; other per-child
        errors are swallowed so the next child gets a chance.

        :raises socket.gaierror: when no child produced an answer.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        any_constrained_tried: bool = False

        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True
                try:
                    results = resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
                    if results:
                        return results
                except socket.gaierror as exc:
                    # A security-related failure must never be masked by a fallback.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue

        if any_constrained_tried:
            # A dedicated resolver claimed this host: do not leak the query
            # to the generic resolvers.
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        for resolver in self.__resolvers():
            try:
                results = resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )
                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
class SupportedQueryType(int, Enum):
    """
    urllib3.future does not need anything else so far. let's be pragmatic.
    Each type is associated with its hex value as per the RFC.
    """

    #: IPv4 address record
    A = 0x0001
    #: IPv6 address record
    AAAA = 0x001C
    #: HTTPS service binding record (type 65)
    HTTPS = 0x0041
class DomainNameServerQuery:
    """
    Minimalist DNS query/message to ask for A, AAAA and HTTPS records.
    Only meant for urllib3.future use. Does not cover all of possible extent of use.
    """

    def __init__(
        self, host: str, query_type: SupportedQueryType, override_id: int | None = None
    ) -> None:
        # 16-bit transaction id, random unless pinned by the caller.
        chosen_id = randint(0x0000, 0xFFFF) if override_id is None else override_id
        self._id = struct.pack("!H", chosen_id)
        self._host = host
        self._query = query_type
        self._flags = struct.pack("!H", 0x0100)  # RD bit set (recursion desired)
        self._qd_count = struct.pack("!H", 1)  # a single question
        self._cached: bytes | None = None  # lazily-built wire form

    @property
    def id(self) -> int:
        """Transaction id as an integer."""
        return struct.unpack("!H", self._id)[0]  # type: ignore[no-any-return]

    @property
    def raw_id(self) -> bytes:
        """Transaction id in its 2-byte big-endian wire form."""
        return self._id

    def __repr__(self) -> str:
        return f"<Query '{self._host}' IN {self._query.name}>"

    def __bytes__(self) -> bytes:
        """Serialize the query to its DNS wire format (built once, then cached)."""
        if self._cached:
            return self._cached
        header = b"".join(
            (
                self._id,
                self._flags,
                self._qd_count,
                b"\x00\x00",  # ANCOUNT
                b"\x00\x00",  # NSCOUNT
                b"\x00\x00",  # ARCOUNT
            )
        )
        question = bytearray()
        for label in self._host.split("."):
            question += struct.pack("!B", len(label))
            question += label.encode("ascii")
        question += b"\x00"  # root label terminates the QNAME
        question += struct.pack("!H", self._query.value)  # QTYPE
        question += struct.pack("!H", 0x0001)  # QCLASS IN
        self._cached = header + bytes(question)
        return self._cached

    @staticmethod
    def bulk(host: str, *types: SupportedQueryType) -> list[DomainNameServerQuery]:
        """Build one query per requested record type for *host*."""
        return [DomainNameServerQuery(host, query_type=qt) for qt in types]
#: Most common status code, not exhaustive at all.
#: Maps DNS RCODE values to human-readable labels for error reporting.
COMMON_RCODE_LABEL: dict[int, str] = {
    0: "No Error",
    1: "Format Error",
    2: "Server Failure",
    3: "Non-Existent Domain",
    5: "Query Refused",
    9: "Not Authorized",
}
#: Raised when a DNS response payload is truncated or otherwise unparsable.
class DomainNameServerParseException(Exception): ...
class DomainNameServerReturn:
    """
    Minimalist DNS response parser. Allow to quickly extract key-data out of it.
    Meant for A, AAAA and HTTPS records. Basically only what we need.
    """

    def __init__(self, payload: bytes) -> None:
        """Eagerly parse *payload* (a raw DNS message, no length prefix).

        :raises DomainNameServerParseException: when the payload is truncated
            or not a well-formed DNS message.
        """
        try:
            # Fixed 12-byte header: ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT.
            up = struct.unpack("!HHHHHH", payload[:12])
            self._id = up[0]
            self._flags = up[1]
            self._qd_count = up[2]
            self._an_count = up[3]
            # RCODE is the low nibble of the flags' second byte.
            self._rcode = payload[3] & 0x0F
            self._hostname: str = ""
            idx = 12
            # Question name: length-prefixed labels terminated by a 0x00 byte.
            while True:
                c = payload[idx]
                if c == 0:
                    idx += 1
                    break
                self._hostname += payload[idx + 1 : idx + 1 + c].decode("ascii") + "."
                idx += c + 1
            self._records: list[tuple[SupportedQueryType, int, str | HttpsRecord]] = []
            if self._an_count:
                idx += 4  # skip the question's QTYPE + QCLASS
                while idx < len(payload):
                    # Answer RR: NAME (compression ptr), TYPE, CLASS, TTL, RDLENGTH.
                    up = struct.unpack("!HHHI", payload[idx : idx + 10])
                    entry_size = struct.unpack("!H", payload[idx + 10 : idx + 12])[0]
                    data = payload[idx + 12 : idx + 12 + entry_size]
                    # Advance to the next RR before decoding, so that a record
                    # we skip (e.g. empty RDATA) can never stall the loop.
                    idx += 12 + entry_size
                    if len(data) == 4:
                        decoded_data: str | HttpsRecord = inet4_ntoa(data)
                    elif len(data) == 16:
                        decoded_data = inet6_ntoa(data)
                    elif data:
                        decoded_data = parse_https_rdata(data)
                    else:
                        # Empty RDATA: nothing to decode, skip the record.
                        continue
                    try:
                        self._records.append(
                            (SupportedQueryType(up[1]), up[-1], decoded_data)
                        )
                    except ValueError:
                        # Record type we do not support: ignore it silently.
                        pass
        except (struct.error, IndexError, ValueError, UnicodeDecodeError) as e:
            raise DomainNameServerParseException(
                "A protocol error occurred while parsing the DNS response payload: "
                f"{str(e)}"
            ) from e

    @property
    def id(self) -> int:
        """Transaction identifier echoed back from the query."""
        return self._id  # type: ignore[no-any-return]

    @property
    def hostname(self) -> str:
        """Queried name, dot-terminated (e.g. ``example.com.``)."""
        return self._hostname

    @property
    def records(self) -> list[tuple[SupportedQueryType, int, str | HttpsRecord]]:
        """Parsed answers as (type, TTL, decoded data) tuples."""
        return self._records

    @property
    def is_found(self) -> bool:
        """True when at least one supported record was parsed."""
        return bool(self._records)

    @property
    def rcode(self) -> int:
        """Raw DNS response code (0 means success)."""
        return self._rcode

    @property
    def is_ok(self) -> bool:
        """True when the server reported no error."""
        return self._rcode == 0

    def __repr__(self) -> str:
        if self.is_ok:
            return f"<Records '{self.hostname}' {self._records}>"
        return f"<DNS Error '{self.hostname}' with Status {self.rcode} ({COMMON_RCODE_LABEL[self.rcode] if self.rcode in COMMON_RCODE_LABEL else 'Unknown'})>"

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
class SystemResolver(BaseResolver):
    """Delegate name resolution to the operating system's native DNS layer."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The OS layer has no notion of a target server/port: drop the hints.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """The system resolver always knows how to answer for the local host."""
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> BaseResolver:
        # Stateless: the same instance stays usable forever.
        return self

    def close(self) -> None:
        pass  # no-op!

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Forward directly to :func:`socket.getaddrinfo`.

        ``quic_upgrade_via_dns_rr`` is ignored: the OS layer cannot surface
        HTTPS resource records.
        """
        # the | tuple[int, bytes] is silently ignored, can't happen with our cases.
        return socket.getaddrinfo(  # type: ignore[return-value]
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )

View File

@@ -0,0 +1,322 @@
from __future__ import annotations
import base64
import binascii
import socket
import struct
import typing
if typing.TYPE_CHECKING:
    # Static-typing only: shape of the dict produced for a parsed HTTPS RR.
    class HttpsRecord(typing.TypedDict):
        """Decoded HTTPS (type 65) resource record."""

        priority: int  # SvcPriority field
        target: str  # TargetName field
        alpn: list[str]  # advertised ALPN protocol identifiers (e.g. "h3")
        ipv4hint: list[str]  # textual IPv4 address hints
        ipv6hint: list[str]  # textual IPv6 address hints
        echconfig: list[str]  # Base64-encoded ECH configurations
def inet4_ntoa(address: bytes) -> str:
    """
    Convert an IPv4 address from bytes to str.
    """
    if len(address) != 4:
        raise ValueError(
            f"IPv4 addresses are 4 bytes long, got {len(address)} byte(s) instead"
        )
    # Iterating bytes yields integers: render each octet in decimal.
    return ".".join(str(octet) for octet in address)
def inet6_ntoa(address: bytes) -> str:
    """
    Convert an IPv6 address from bytes to str.

    The leftmost longest run of two-or-more zero groups is compressed to
    ``::``; leading-zero mapped/compatible addresses render their IPv4 tail
    in dotted-quad form.
    """
    if len(address) != 16:
        raise ValueError(
            f"IPv6 addresses are 16 bytes long, got {len(address)} byte(s) instead"
        )
    raw_hex = binascii.hexlify(address)
    # Eight 16-bit groups, leading zeros stripped ("0000" -> "0").
    groups = [
        raw_hex[pos : pos + 4].decode().lstrip("0") or "0"
        for pos in range(0, len(raw_hex), 4)
    ]
    # Locate the leftmost longest run of consecutive "0" groups.
    best_start = 0
    best_len = 0
    run_start = None
    for pos in range(9):  # the pass at pos == 8 closes a trailing run
        group_is_zero = pos < 8 and groups[pos] == "0"
        if group_is_zero:
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            if pos - run_start > best_len:
                best_start, best_len = run_start, pos - run_start
            run_start = None
    if best_len > 1:
        if best_start == 0 and (
            best_len == 6 or (best_len == 5 and groups[5] == "ffff")
        ):
            # We have an embedded IPv4 address
            prefix = "::" if best_len == 6 else "::ffff:"
            return prefix + inet4_ntoa(address[12:])
        return (
            ":".join(groups[:best_start])
            + "::"
            + ":".join(groups[best_start + best_len :])
        )
    return ":".join(groups)
def packet_fragment(payload: bytes, *identifiers: bytes) -> tuple[bytes, ...]:
    """Split *payload* into individual DNS messages.

    A single read may yield several concatenated DNS responses; the 2-byte
    transaction *identifiers* of the in-flight queries are used to locate
    where each message starts.

    :raises ValueError: when none of the identifiers can be found.
    """
    results = []
    offset = 0
    start_packet_idx = []
    lead_identifier = None
    # First pass: find the message whose id sits within the leading 12-byte
    # DNS header region of the payload.
    for identifier in identifiers:
        idx = payload[:12].find(identifier)
        if idx == -1:
            continue
        if idx != 0:
            # Id found shifted from byte 0: normalize later matches by
            # the same offset.
            offset = idx
        start_packet_idx.append(idx - offset)
        lead_identifier = identifier
        break
    # Second pass: locate the remaining messages by their ids.
    for identifier in identifiers:
        if identifier == lead_identifier:
            continue
        if offset == 0:
            # NOTE(review): the b"\x02" prefix presumably anchors the id to
            # reduce false positives inside record data — confirm against the
            # sending side before relying on it.
            idx = payload.find(b"\x02" + identifier)
        else:
            idx = payload.find(identifier)
        if idx == -1:
            continue
        start_packet_idx.append(idx - offset)
    if not start_packet_idx:
        raise ValueError(
            "no identifiable dns message emerged from given payload. "
            "this should not happen at all. networking issue?"
        )
    if len(start_packet_idx) == 1:
        # Single message: hand the payload back untouched.
        return (payload,)
    # Cut the payload at each message start, in ascending order.
    start_packet_idx = sorted(start_packet_idx)
    previous_idx = None
    for idx in start_packet_idx:
        if previous_idx is None:
            previous_idx = idx
            continue
        results.append(payload[previous_idx:idx])
        previous_idx = idx
    results.append(payload[previous_idx:])
    return tuple(results)
def is_ipv4(addr: str) -> bool:
    """Tell whether *addr* parses as an IPv4 address (``inet_aton`` rules)."""
    try:
        socket.inet_aton(addr)
    except OSError:
        return False
    return True
def is_ipv6(addr: str) -> bool:
    """Tell whether *addr* parses as an IPv6 address (``inet_pton`` rules)."""
    try:
        socket.inet_pton(socket.AF_INET6, addr)
    except OSError:
        return False
    return True
def validate_length_of(hostname: str) -> None:
    """Verify *hostname* against the RFC 1035 length limits.

    RFC 1035 imposes a limit on a domain name length (253 visible characters)
    and on every individual label (63 characters). We verify it there.

    :raises UnicodeError: when the name or one of its labels is too long.
    """
    if len(hostname.strip(".")) > 253:
        raise UnicodeError("hostname to resolve exceed 253 characters")
    # generator instead of a list: avoids an intermediate allocation and
    # short-circuits on the first oversized label
    elif any(len(label) > 63 for label in hostname.split(".")):
        raise UnicodeError("at least one label to resolve exceed 63 characters")
def rfc1035_should_read(payload: bytes) -> bool:
    """Tell whether *payload* is an incomplete DNS-over-TCP buffer.

    Messages over TCP are framed by a 2-byte big-endian length prefix
    (RFC 1035 section 4.2.2). Returns True when at least one more read is
    required to complete the trailing message, False when the buffer holds
    exactly one or more whole messages.
    """
    if not payload:
        return False
    if len(payload) <= 2:
        return True
    cursor = payload
    while True:
        # bugfix: a truncated length prefix (a single byte left over after a
        # complete message) used to crash struct.unpack with struct.error —
        # it simply means more data must be read.
        if len(cursor) < 2:
            return True
        expected_size: int = struct.unpack("!H", cursor[:2])[0]
        remaining = len(cursor) - 2
        if remaining == expected_size:
            return False
        if remaining < expected_size:
            return True
        cursor = cursor[2 + expected_size :]
def rfc1035_unpack(payload: bytes) -> tuple[bytes, ...]:
    """Split a DNS-over-TCP stream (2-byte length-prefixed frames) into messages."""
    messages = []
    pos = 0
    total = len(payload)
    while pos < total:
        (size,) = struct.unpack("!H", payload[pos : pos + 2])
        messages.append(payload[pos + 2 : pos + 2 + size])
        pos += 2 + size
    return tuple(messages)
def rfc1035_pack(message: bytes) -> bytes:
    """Frame *message* with its 2-byte big-endian length (RFC 1035 TCP framing)."""
    prefix = struct.pack("!H", len(message))
    return prefix + message
def read_name(data: bytes, offset: int) -> tuple[str, int]:
    """
    Read a DNS-encoded name (with compression pointers) from data[offset:].
    Returns (name, new_offset).
    """
    labels = []
    while True:
        length = data[offset]
        # compression pointer? (two high bits set — RFC 1035 section 4.1.4)
        if length & 0xC0 == 0xC0:
            # the 14 low bits of the 2-byte field locate the rest of the name
            pointer = struct.unpack_from("!H", data, offset)[0] & 0x3FFF
            subname, _ = read_name(data, pointer)
            labels.append(subname)
            offset += 2
            # a pointer always terminates the current name
            break
        if length == 0:
            # root label: end of the name
            offset += 1
            break
        offset += 1
        labels.append(data[offset : offset + length].decode())
        offset += length
    # NOTE(review): no guard against malicious pointer loops — a
    # self-referencing pointer chain recurses until RecursionError;
    # confirm inputs are bounded upstream.
    return ".".join(labels), offset
def parse_echconfigs(buf: bytes) -> list[str]:
    """
    buf is the raw bytes of the ECHConfig vector:
      - 2-byte total length, then for each:
      - 2-byte cfg length + that many bytes of cfg
    We return a list of Base64 strings (one per config).
    """
    if len(buf) < 2:
        return []
    (declared_total,) = struct.unpack_from("!H", buf, 0)
    limit = 2 + declared_total
    configs: list[str] = []
    cursor = 2
    while cursor + 2 <= limit:
        (cfg_size,) = struct.unpack_from("!H", buf, cursor)
        cursor += 2
        configs.append(base64.b64encode(buf[cursor : cursor + cfg_size]).decode())
        cursor += cfg_size
    return configs
def parse_https_rdata(rdata: bytes) -> HttpsRecord:
    """
    Parse the RDATA of an SVCB/HTTPS record.
    Returns a dict with keys: priority, target, alpn, ipv4hint, ipv6hint, echconfig.
    """
    off = 0
    # SvcPriority (2 bytes) then the TargetName as a DNS-encoded name
    priority = struct.unpack_from("!H", rdata, off)[0]
    off += 2
    target, off = read_name(rdata, off)
    # pull out all the key/value params (2-byte key, 2-byte length, raw value)
    params = {}
    while off + 4 <= len(rdata):
        key, length = struct.unpack_from("!HH", rdata, off)
        off += 4
        params[key] = rdata[off : off + length]
        off += length

    # decode ALPN (key=1), IPv4 (4), IPv6 (6), ECHConfig (5)
    def parse_alpn(buf: bytes) -> list[str]:
        # the alpn value is a sequence of 1-byte-length-prefixed protocol ids
        out = []
        i: int = 0
        while i < len(buf):
            ln = buf[i]
            out.append(buf[i + 1 : i + 1 + ln].decode())
            i += 1 + ln
        return out

    alpn: list[str] = parse_alpn(params.get(1, b""))
    # address hints are fixed-size lists: 4 bytes per IPv4, 16 per IPv6
    ipv4 = [
        inet4_ntoa(params[4][i : i + 4]) for i in range(0, len(params.get(4, b"")), 4)
    ]
    ipv6 = [
        inet6_ntoa(params[6][i : i + 16]) for i in range(0, len(params.get(6, b"")), 16)
    ]
    echconfs = parse_echconfigs(params.get(5, b""))
    return {
        "priority": priority,
        "target": target or ".",  # empty name → root
        "alpn": alpn,
        "ipv4hint": ipv4,
        "ipv6hint": ipv6,
        "echconfig": echconfs,
    }
# Explicit wildcard-import surface of this helper module.
# read_name and parse_echconfigs are not exported here.
__all__ = (
    "inet4_ntoa",
    "inet6_ntoa",
    "packet_fragment",
    "is_ipv4",
    "is_ipv6",
    "validate_length_of",
    "rfc1035_pack",
    "rfc1035_unpack",
    "rfc1035_should_read",
    "parse_https_rdata",
)

View File

@@ -0,0 +1,453 @@
"""
This module contains provisional support for SOCKS proxies from within
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
SOCKS5. To enable its functionality, either install python-socks or install this
module with the ``socks`` extra.
The SOCKS implementation supports the full range of urllib3 features. It also
supports the following SOCKS features:
- SOCKS4A (``proxy_url='socks4a://...``)
- SOCKS4 (``proxy_url='socks4://...``)
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
- Usernames and passwords for the SOCKS proxy
.. note::
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
your ``proxy_url`` to ensure that DNS resolution is done from the remote
server instead of client-side when connecting to a domain name.
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
supports IPv4, IPv6, and domain names.
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
will be sent as the ``userid`` section of the SOCKS request:
.. code-block:: python
proxy_url="socks4a://<userid>@proxy-host"
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
of the ``proxy_url`` will be sent as the username/password to authenticate
with the proxy:
.. code-block:: python
proxy_url="socks5h://<username>:<password>@proxy-host"
"""
from __future__ import annotations

import warnings

#: We purposely want to support PySocks[...] due to our shadowing of the legacy "urllib3". "Dot not disturb" policy.
BYPASS_SOCKS_LEGACY: bool = False

try:
    # preferred backend: python-socks (provides both sync and asyncio proxies)
    from python_socks import (
        ProxyConnectionError,
        ProxyError,
        ProxyTimeoutError,
        ProxyType,
    )
    from python_socks.sync import Proxy

    from ._socks_override import AsyncioProxy
except ImportError:
    from ..exceptions import DependencyWarning

    # fallback backend: if PySocks is present, reuse the legacy implementation
    try:
        import socks  # noqa
    except ImportError:
        warnings.warn(
            (
                "SOCKS support in urllib3.future requires the installation of an optional "
                "dependency: python-socks. For more information, see "
                "https://urllib3future.readthedocs.io/en/latest/contrib.html#socks-proxies"
            ),
            DependencyWarning,
        )
    else:
        from ._socks_legacy import (
            SOCKSConnection,
            SOCKSHTTPConnectionPool,
            SOCKSHTTPSConnection,
            SOCKSHTTPSConnectionPool,
            SOCKSProxyManager,
        )

        BYPASS_SOCKS_LEGACY = True

    # neither backend is available: surface the original ImportError
    if not BYPASS_SOCKS_LEGACY:
        raise
# Everything below implements SOCKS on top of python-socks; it is skipped
# entirely when the legacy PySocks path was taken above.
if not BYPASS_SOCKS_LEGACY:
    import typing
    from socket import socket
    from socket import timeout as SocketTimeout

    # asynchronous part
    from .._async.connection import AsyncHTTPConnection, AsyncHTTPSConnection
    from .._async.connectionpool import (
        AsyncHTTPConnectionPool,
        AsyncHTTPSConnectionPool,
    )
    from .._async.poolmanager import AsyncPoolManager
    from .._typing import _TYPE_SOCKS_OPTIONS
    from ..backend import HttpVersion

    # synchronous part
    from ..connection import HTTPConnection, HTTPSConnection
    from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
    from ..contrib.ssa import AsyncSocket
    from ..exceptions import ConnectTimeoutError, NewConnectionError
    from ..poolmanager import PoolManager
    from ..util.url import parse_url

    # the interpreter may have been built without ssl support
    try:
        import ssl
    except ImportError:
        ssl = None  # type: ignore[assignment]
class SOCKSConnection(HTTPConnection):  # type: ignore[no-redef]
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # resolved proxy parameters: version, host/port, credentials, rdns flag
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    def _new_conn(self) -> socket:
        """
        Establish a new connection via the SOCKS proxy.

        :raises ConnectTimeoutError: the proxy did not answer within ``self.timeout``.
        :raises NewConnectionError: the TCP connect or the SOCKS negotiation failed.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # A SOCKS tunnel is TCP-only: forward 3-tuple options as-is and
            # drop 4-tuple options explicitly tagged "udp".
            only_tcp_options = []
            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])
            extra_kw["socket_options"] = only_tcp_options

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = Proxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            _socket = self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                # bugfix: "socket_options" is only set in extra_kw when
                # self.socket_options is truthy; the previous direct
                # extra_kw["socket_options"] lookup raised KeyError when
                # socket_options was empty or None.
                socket_options=extra_kw.get("socket_options"),
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            # our dependency started to deprecate passing "_socket"
            # which is ... vital for our integration. We'll start by silencing the warning.
            # then we'll think on how to proceed.
            # A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
            # B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)

                return p.connect(
                    self.host,
                    self.port,
                    self.timeout,
                    _socket=_socket,
                )
        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e
        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):  # type: ignore[no-redef]
    """TLS connection through a SOCKS tunnel: tunnel setup comes from
    SOCKSConnection._new_conn, TLS behavior from HTTPSConnection."""

    pass
class SOCKSHTTPConnectionPool(HTTPConnectionPool):  # type: ignore[no-redef]
    """Pool of plain-text connections established through a SOCKS proxy."""

    ConnectionCls = SOCKSConnection
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):  # type: ignore[no-redef]
    """Pool of TLS connections established through a SOCKS proxy."""

    ConnectionCls = SOCKSHTTPSConnection
class SOCKSProxyManager(PoolManager):  # type: ignore[no-redef]
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # credentials may also come embedded inside the URL authority
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (SOCKS protocol version, remote DNS resolution enabled)
        scheme_dispatch = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }
        version_and_rdns = scheme_dispatch.get(parsed.scheme)
        if version_and_rdns is None:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
        socks_version, rdns = version_and_rdns

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        # HTTP/3 is always disabled when tunneling through SOCKS
        if "disabled_svn" not in connection_pool_kw:
            connection_pool_kw["disabled_svn"] = set()
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
class AsyncSOCKSConnection(AsyncHTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # resolved proxy parameters: version, host/port, credentials, rdns flag
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    async def _new_conn(self) -> AsyncSocket:  # type: ignore[override]
        """
        Establish a new connection via the SOCKS proxy.

        :raises ConnectTimeoutError: the proxy did not answer within ``self.timeout``.
        :raises NewConnectionError: the TCP connect or the SOCKS negotiation failed.
        """
        extra_kw: dict[str, typing.Any] = {}
        if self.source_address:
            extra_kw["source_address"] = self.source_address

        if self.socket_options:
            # A SOCKS tunnel is TCP-only: forward 3-tuple options as-is and
            # drop 4-tuple options explicitly tagged "udp".
            only_tcp_options = []
            for opt in self.socket_options:
                if len(opt) == 3:
                    only_tcp_options.append(opt)
                elif len(opt) == 4:
                    protocol: str = opt[3].lower()
                    if protocol == "udp":
                        continue
                    only_tcp_options.append(opt[:3])
            extra_kw["socket_options"] = only_tcp_options

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = AsyncioProxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            _socket = await self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                # bugfix: "socket_options" is only set in extra_kw when
                # self.socket_options is truthy; the previous direct
                # extra_kw["socket_options"] lookup raised KeyError when
                # socket_options was empty or None.
                socket_options=extra_kw.get("socket_options"),
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            return await p.connect(
                self.host,
                self.port,
                self.timeout,
                _socket,
            )
        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e
        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
# We don't need to duplicate the Verified/Unverified distinction from
# urllib3/connection.py here because the HTTPSConnection will already have been
# correctly set to either the Verified or Unverified form by that module. This
# means the SOCKSHTTPSConnection will automatically be the correct type.
class AsyncSOCKSHTTPSConnection(AsyncSOCKSConnection, AsyncHTTPSConnection):
    """Async TLS connection through a SOCKS tunnel."""

    pass
class AsyncSOCKSHTTPConnectionPool(AsyncHTTPConnectionPool):
    """Async pool of plain-text connections established through a SOCKS proxy."""

    ConnectionCls = AsyncSOCKSConnection
class AsyncSOCKSHTTPSConnectionPool(AsyncHTTPSConnectionPool):
    """Async pool of TLS connections established through a SOCKS proxy."""

    ConnectionCls = AsyncSOCKSHTTPSConnection
class AsyncSOCKSProxyManager(AsyncPoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": AsyncSOCKSHTTPConnectionPool,
        "https": AsyncSOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # credentials may also come embedded inside the URL authority
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # scheme -> (SOCKS protocol version, remote DNS resolution enabled)
        scheme_dispatch = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }
        version_and_rdns = scheme_dispatch.get(parsed.scheme)
        if version_and_rdns is None:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
        socks_version, rdns = version_and_rdns

        self.proxy_url = proxy_url

        socks_options = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }
        connection_pool_kw["_socks_options"] = socks_options

        # HTTP/3 is always disabled when tunneling through SOCKS
        if "disabled_svn" not in connection_pool_kw:
            connection_pool_kw["disabled_svn"] = set()
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = AsyncSOCKSProxyManager.pool_classes_by_scheme
# Wildcard-import surface. The async flavors are only appended on the
# python-socks path: the legacy PySocks fallback does not define them.
__all__ = [
    "SOCKSConnection",
    "SOCKSProxyManager",
    "SOCKSHTTPSConnection",
    "SOCKSHTTPSConnectionPool",
    "SOCKSHTTPConnectionPool",
]

if not BYPASS_SOCKS_LEGACY:
    __all__ += [
        "AsyncSOCKSConnection",
        "AsyncSOCKSHTTPSConnection",
        "AsyncSOCKSHTTPConnectionPool",
        "AsyncSOCKSHTTPSConnectionPool",
        "AsyncSOCKSProxyManager",
    ]

View File

@@ -0,0 +1,520 @@
from __future__ import annotations

import asyncio
import platform
import socket
import typing
import warnings

from ._timeout import timeout
from ._gro import open_dgram_connection, DatagramReader, DatagramWriter

# exception raised on timeouts to stay compatible with synchronous callers
StandardTimeoutError = socket.timeout

# concurrent.futures.TimeoutError is an alias of TimeoutError on recent Pythons
try:
    from concurrent.futures import TimeoutError as FutureTimeoutError
except ImportError:
    FutureTimeoutError = TimeoutError  # type: ignore[misc]

# asyncio.exceptions.TimeoutError merged into builtin TimeoutError in 3.11
try:
    AsyncioTimeoutError = asyncio.exceptions.TimeoutError
except AttributeError:
    AsyncioTimeoutError = TimeoutError  # type: ignore[misc]

if typing.TYPE_CHECKING:
    import ssl

    from typing_extensions import Literal

    from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
def _can_shutdown_and_close_selector_loop_bug() -> bool:
    """Whether the full shutdown/close teardown procedure is safe here.

    Returns False only on Windows with CPython 3.7 before 3.7.17, where
    our shutdown procedure crashes asyncio's SelectorEventLoop.
    """
    # note: `platform` is already imported at module level
    if platform.system() == "Windows" and platform.python_version_tuple()[:2] == (
        "3",
        "7",
    ):
        # the check-based fix is present from the 3.7.17 maintenance release on
        return int(platform.python_version_tuple()[-1]) >= 17
    return True


# Windows + asyncio bug where doing our shutdown procedure induce a crash
# in SelectorLoop
#   File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select
#     r, w, x = select.select(r, w, w, timeout)
#   [WinError 10038] An operation was attempted on something that is not a socket
_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = not _can_shutdown_and_close_selector_loop_bug()
class AsyncSocket:
    """
    This class is brought to add a level of abstraction to an asyncio transport (reader, or writer)
    We don't want to have two distinct code (async/sync) but rather a unified and easily verifiable
    code base.
    'ssa' stands for Simplified - Socket - Asynchronous.
    """

    def __init__(
        self,
        family: socket.AddressFamily = socket.AF_INET,
        type: socket.SocketKind = socket.SOCK_STREAM,
        proto: int = -1,
        fileno: int | None = None,
    ) -> None:
        self.family: socket.AddressFamily = family
        self.type: socket.SocketKind = type
        self.proto: int = proto
        self._fileno: int | None = fileno

        self._connect_called: bool = False
        self._established: asyncio.Event = asyncio.Event()

        # we do that everytime to forward properly options / advanced settings
        self._sock: socket.socket = socket.socket(
            family=self.family, type=self.type, proto=self.proto, fileno=fileno
        )
        # set nonblocking / or cause the loop to block with dgram socket...
        self._sock.settimeout(0)

        self._writer: asyncio.StreamWriter | DatagramWriter | None = None
        self._reader: asyncio.StreamReader | DatagramReader | None = None

        # serialize concurrent readers and writers so frames don't interleave
        self._writer_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._reader_semaphore: asyncio.Semaphore = asyncio.Semaphore()

        self._addr: tuple[str, int] | tuple[str, int, int, int] | None = None
        self._external_timeout: float | int | None = None
        self._tls_in_tls = False

    def fileno(self) -> int:
        """File descriptor of the underlying socket (or the one given at init)."""
        return self._fileno if self._fileno is not None else self._sock.fileno()

    async def wait_for_close(self) -> None:
        """Best-effort wait for the writer transport to be fully closed.

        close() resets ``_connect_called``; a socket that is still marked
        connected (or never produced a writer) has nothing to wait for.
        """
        if self._connect_called:
            return
        if self._writer is None:
            return
        try:
            # report made in https://github.com/jawah/niquests/issues/184
            # made us believe that sometime ssl_transport is freed before
            # getting there. So we could end up there with a half broken
            # writer state. The original user was using Windows at the time.
            is_ssl = self._writer.get_extra_info("ssl_object") is not None
        except AttributeError:
            is_ssl = False
        if is_ssl:
            # Give the connection a chance to write any data in the buffer,
            # and then forcibly tear down the SSL connection.
            await asyncio.sleep(0)
            self._writer.transport.abort()
        try:
            # wait_closed can hang indefinitely!
            # on Python 3.8 and 3.9
            # there's some case where Python want an explicit EOT
            # (spoiler: it was a CPython bug) fixed in recent interpreters.
            # to circumvent this and still have a proper close
            # we enforce a maximum delay (1000ms).
            async with timeout(1):
                await self._writer.wait_closed()
        except TimeoutError:
            pass

    def close(self) -> None:
        """Release the transport and the underlying socket, tolerating the
        various loop-implementation quirks (uvloop, Windows SelectorEventLoop)."""
        if self._writer is not None:
            self._writer.close()

        edge_case_close_bug_exist = _CPYTHON_SELECTOR_CLOSE_BUG_EXIST

        # Windows + asyncio + asyncio.SelectorEventLoop limits us on how far
        # we can safely shutdown the socket.
        if not edge_case_close_bug_exist and platform.system() == "Windows":
            if hasattr(asyncio, "SelectorEventLoop") and isinstance(
                asyncio.get_running_loop(), asyncio.SelectorEventLoop
            ):
                edge_case_close_bug_exist = True

        try:
            # see https://github.com/MagicStack/uvloop/issues/241
            # and https://github.com/jawah/niquests/issues/166
            # probably not just uvloop.
            uvloop_edge_case_bug = False

            # keep track of our clean exit procedure
            shutdown_called = False
            close_called = False

            if hasattr(self._sock, "shutdown"):
                try:
                    self._sock.shutdown(socket.SHUT_RD)
                    shutdown_called = True
                except TypeError:
                    uvloop_edge_case_bug = True
                    # uvloop don't support shutdown! and sometime does not support close()...
                    # see https://github.com/jawah/niquests/issues/166 for ctx.
                    try:
                        self._sock.close()
                        close_called = True
                    except TypeError:
                        # last chance of releasing properly the underlying fd!
                        try:
                            direct_sock = socket.socket(fileno=self._sock.fileno())
                        except (OSError, ValueError):
                            pass
                        else:
                            try:
                                direct_sock.shutdown(socket.SHUT_RD)
                                shutdown_called = True
                            except OSError:
                                warnings.warn(
                                    (
                                        "urllib3-future is unable to properly close your async socket. "
                                        "This mean that you are probably using an asyncio implementation like uvloop "
                                        "that does not support shutdown() or/and close() on the socket transport. "
                                        "This will lead to unclosed socket (fd)."
                                    ),
                                    ResourceWarning,
                                )
                            finally:
                                direct_sock.detach()

            # we have to force call close() on our sock object (even after shutdown).
            # or we'll get a resource warning for sure!
            if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"):
                if not uvloop_edge_case_bug and not edge_case_close_bug_exist:
                    try:
                        self._sock.close()
                        close_called = True
                    except (OSError, TypeError):
                        pass

            if not close_called or not shutdown_called:
                # this branch detect whether we have an asyncio.TransportSocket instead of socket.socket.
                if hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                    try:
                        self._sock._sock.close()
                    except (AttributeError, OSError, TypeError):
                        pass
        except (
            OSError
        ):  # branch where we failed to connect and still try to release resource
            if isinstance(self._sock, socket.socket):
                try:
                    self._sock.close()  # don't call close on asyncio.TransportSocket
                except (OSError, TypeError, AttributeError):
                    pass
            elif hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                try:
                    self._sock._sock.detach()
                except (AttributeError, OSError, TypeError):
                    pass

        self._connect_called = False
        self._established.clear()

    async def wait_for_readiness(self) -> None:
        """Block until the connection (and optional TLS wrap) is established."""
        await self._established.wait()

    def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None:
        """Proxy to the raw socket's setsockopt."""
        self._sock.setsockopt(__level, __optname, __value)

    @typing.overload
    def getsockopt(self, __level: int, __optname: int) -> int: ...

    @typing.overload
    def getsockopt(self, __level: int, __optname: int, buflen: int) -> bytes: ...

    def getsockopt(
        self, __level: int, __optname: int, buflen: int | None = None
    ) -> int | bytes:
        """Proxy to the raw socket's getsockopt."""
        if buflen is None:
            return self._sock.getsockopt(__level, __optname)
        return self._sock.getsockopt(__level, __optname, buflen)

    def should_connect(self) -> bool:
        """True while connect() has not been called (or after close())."""
        return self._connect_called is False

    async def connect(self, addr: tuple[str, int] | tuple[str, int, int, int]) -> None:
        """Connect the raw socket then attach asyncio reader/writer pairs.

        :raises OSError: when called twice on an established connection.
        :raises socket.timeout: when ``settimeout`` was set and expired.
        :raises ConnectionError: on a suspected FD allocation race.
        """
        if self._connect_called:
            raise OSError(
                "attempted to connect twice on a already established connection"
            )

        self._connect_called = True

        # there's a particularity on Windows
        # we must not forward non-IP in addr due to
        # a limitation in the network bridge used in asyncio
        if platform.system() == "Windows":
            from ..resolver.utils import is_ipv4, is_ipv6

            host, port = addr[:2]

            if not is_ipv4(host) and not is_ipv6(host):
                res = await asyncio.get_running_loop().getaddrinfo(
                    host,
                    port,
                    family=self.family,
                    type=self.type,
                )

                if not res:
                    raise socket.gaierror(f"unable to resolve hostname {host}")

                addr = res[0][-1]

        if self._external_timeout is not None:
            try:
                async with timeout(self._external_timeout):
                    await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                # allow a retry after a timeout
                self._connect_called = False
                raise StandardTimeoutError from e
            except RuntimeError:
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )
        else:
            try:
                await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except RuntimeError:  # Defensive: CPython might raise RuntimeError if there is a FD allocation error.
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )

        if self.type == socket.SOCK_STREAM or self.type == -1:  # type: ignore[comparison-overlap]
            self._reader, self._writer = await asyncio.open_connection(sock=self._sock)
        elif self.type == socket.SOCK_DGRAM:
            self._reader, self._writer = await open_dgram_connection(sock=self._sock)

        # can become an asyncio.TransportSocket
        assert self._writer is not None
        self._sock = self._writer.get_extra_info("socket", self._sock)

        self._addr = addr
        self._established.set()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Upgrade the established stream to TLS in-place and morph this
        instance into a :class:`SSLAsyncSocket`."""
        await self._established.wait()
        self._established.clear()

        # only if Python <= 3.10
        try:
            setattr(
                asyncio.sslproto._SSLProtocolTransport,  # type: ignore[attr-defined]
                "_start_tls_compatible",
                True,
            )
        except AttributeError:
            pass

        if self.type == socket.SOCK_STREAM:
            assert self._writer is not None
            assert isinstance(self._writer, asyncio.StreamWriter)

            # bellow is hard to maintain. Starting with 3.11+, it is useless.
            protocol = self._writer._protocol  # type: ignore[attr-defined]
            await self._writer.drain()

            new_transport = await self._writer._loop.start_tls(  # type: ignore[attr-defined]
                self._writer._transport,  # type: ignore[attr-defined]
                protocol,
                ctx,
                server_side=False,
                server_hostname=server_hostname,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )

            self._writer._transport = new_transport  # type: ignore[attr-defined]

            transport = self._writer.transport
            protocol._stream_writer = self._writer
            protocol._transport = transport
            protocol._over_ssl = transport.get_extra_info("sslcontext") is not None

            self._tls_ctx = ctx
        else:
            raise RuntimeError("Unsupported socket type")

        self._established.set()
        self.__class__ = SSLAsyncSocket
        return self  # type: ignore[return-value]

    async def recv(self, size: int = -1) -> bytes | list[bytes]:
        """Receive data from the socket.

        Returns ``bytes`` for a single datagram (or stream chunk), or
        ``list[bytes]`` when GRO / batch-receive delivered multiple
        coalesced datagrams in one syscall. The caller can then feed
        all segments to the QUIC state-machine in a tight loop before
        probing, avoiding per-datagram overhead."""
        if size == -1:
            size = 65536

        assert self._reader is not None

        await self._established.wait()
        await self._reader_semaphore.acquire()

        try:
            if self._external_timeout is not None:
                try:
                    async with timeout(self._external_timeout):
                        return await self._reader.read(n=size)
                except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                    # bugfix: the semaphore was released here *and* in the
                    # finally clause below, over-incrementing it past its
                    # initial value and breaking reader mutual exclusion.
                    # The finally clause alone is the single release point.
                    raise StandardTimeoutError from e
                except OSError as e:  # Defensive: treat any OSError as ConnReset!
                    raise ConnectionResetError() from e
            # NOTE(review): the no-timeout path propagates OSError untouched,
            # unlike the timed path above — confirm the asymmetry is intended.
            return await self._reader.read(n=size)
        finally:
            self._reader_semaphore.release()

    async def read_exact(self, size: int = -1) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv(size=size)

    async def read(self) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv()

    async def sendall(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Write *data* and drain, holding the writer semaphore throughout."""
        assert self._writer is not None

        await self._established.wait()
        await self._writer_semaphore.acquire()

        # (a no-op "except Exception: raise" was removed; finally suffices)
        try:
            self._writer.write(data)  # type: ignore[arg-type]
            await self._writer.drain()
        finally:
            self._writer_semaphore.release()

    async def write_all(
        self, data: bytes | bytearray | memoryview | list[bytes]
    ) -> None:
        """Just an alias for sendall(), it is needed due to our custom AsyncSocks override."""
        await self.sendall(data)

    async def send(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Alias for sendall()."""
        await self.sendall(data)

    def settimeout(self, __value: float | None = None) -> None:
        """Record the timeout applied to connect/recv (None disables it)."""
        self._external_timeout = __value

    def gettimeout(self) -> float | None:
        """Return the timeout previously set via settimeout()."""
        return self._external_timeout

    def getpeername(self) -> tuple[str, int]:
        """Proxy to the raw socket's getpeername."""
        return self._sock.getpeername()  # type: ignore[no-any-return]

    def bind(self, addr: tuple[str, int]) -> None:
        """Proxy to the raw socket's bind."""
        self._sock.bind(addr)
class SSLAsyncSocket(AsyncSocket):
    """TLS-wrapped variant of AsyncSocket; instances are produced by
    AsyncSocket.wrap_socket via an in-place ``__class__`` swap."""

    # TLS context installed by wrap_socket
    _tls_ctx: ssl.SSLContext
    # True when this TLS session runs inside another TLS tunnel
    _tls_in_tls: bool

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Proxy to the underlying SSL object's getpeercert."""
        return self.sslobj.getpeercert(binary_form=binary_form)  # type: ignore[return-value]

    def selected_alpn_protocol(self) -> str | None:
        """ALPN protocol negotiated during the handshake, if any."""
        return self.sslobj.selected_alpn_protocol()

    @property
    def sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        """The live SSL object extracted from the writer transport.

        :raises RuntimeError: when no ``ssl_object`` is attached (e.g. torn down).
        """
        if self._writer is not None:
            sslobj: ssl.SSLSocket | ssl.SSLObject | None = self._writer.get_extra_info(
                "ssl_object"
            )

            if sslobj is not None:
                return sslobj

        raise RuntimeError(
            '"ssl_object" could not be extracted from this SslAsyncSock instance'
        )

    def version(self) -> str | None:
        """Negotiated TLS protocol version string."""
        return self.sslobj.version()

    @property
    def context(self) -> ssl.SSLContext:
        """SSL context attached to the live SSL object."""
        return self.sslobj.context

    @property
    def _sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        # compatibility alias for code expecting the private ssl attribute name
        return self.sslobj

    def cipher(self) -> tuple[str, str, int] | None:
        """(cipher name, protocol version, secret bits) of the session."""
        return self.sslobj.cipher()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Wrap again: marks the session as TLS-in-TLS before delegating."""
        self._tls_in_tls = True
        return await super().wrap_socket(
            ctx,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
def _has_complete_support_dgram() -> bool:
    r"""A bug exist in PyPy asyncio implementation that prevent us to use a DGRAM socket.

    This piece of code inform us, potentially, if PyPy has fixed the winapi implementation.
    See https://github.com/pypy/pypy/issues/4008 and https://github.com/jawah/niquests/pull/87
    The stacktrace look as follows:
      File "C:\hostedtoolcache\windows\PyPy\3.10.13\x86\Lib\asyncio\windows_events.py", line 594, in connect
        _overlapped.WSAConnect(conn.fileno(), address)
      AttributeError: module '_overlapped' has no attribute 'WSAConnect'
    """
    # note: the docstring is now a raw string — the previous plain string
    # silently mangled "\3", "\x86" and "\a" into escape characters.
    # `platform` is already imported at module level; the redundant local
    # import was removed.
    if platform.system() == "Windows" and platform.python_implementation() == "PyPy":
        try:
            import _overlapped  # type: ignore[import-not-found]
        except ImportError:  # Defensive:
            return False

        # the PyPy fix is detectable by the presence of WSAConnect
        return hasattr(_overlapped, "WSAConnect")

    return True
# explicit wildcard-import surface of this module
__all__ = (
    "AsyncSocket",
    "SSLAsyncSocket",
    "_has_complete_support_dgram",
)

View File

@@ -0,0 +1,640 @@
"""
High-performance asyncio DatagramTransport with Linux-specific UDP
receive/send coalescing:
- GRO (receive): ``setsockopt(SOL_UDP, UDP_GRO)`` + ``recvmsg`` cmsg
- GSO (send): ``sendmsg`` with ``UDP_SEGMENT`` cmsg
All other platforms fall back to the standard asyncio DatagramTransport.
"""
from __future__ import annotations
import asyncio
import collections
import socket
import struct
from collections import deque
from typing import Any, Callable
from ..._constant import UDP_LINUX_GRO, UDP_LINUX_SEGMENT
# native-endian unsigned short: the UDP_GRO cmsg payload (segment size)
_UINT16 = struct.Struct("=H")
# receive buffer large enough for a maximally coalesced GRO read
_DEFAULT_GRO_BUF = 65535

# Flow control watermarks for the custom write queue
_HIGH_WATERMARK = 64 * 1024
_LOW_WATERMARK = 16 * 1024

# GSO kernel limit: max segments per sendmsg call
_GSO_MAX_SEGMENTS = 64
def _sock_has_gro(sock: socket.socket) -> bool:
    """Check if GRO is enabled on *sock* (caller must have set it)."""
    try:
        enabled = sock.getsockopt(socket.SOL_UDP, UDP_LINUX_GRO)
    except OSError:
        return False
    return enabled == 1
def _sock_has_gso(sock: socket.socket) -> bool:
    """Report whether the kernel accepts the UDP_SEGMENT option on *sock*."""
    try:
        sock.getsockopt(socket.SOL_UDP, UDP_LINUX_SEGMENT)
    except OSError:
        return False
    return True
def _split_gro_buffer(buf: bytes, segment_size: int) -> list[bytes]:
if segment_size <= 0 or len(buf) <= segment_size:
return [buf]
segments = []
mv = memoryview(buf)
for offset in range(0, len(buf), segment_size):
segments.append(bytes(mv[offset : offset + segment_size]))
return segments
def _group_by_segment_size(datagrams: list[bytes]) -> list[tuple[int, list[bytes]]]:
    """Group consecutive equal-sized datagrams for Linux UDP GSO.

    GSO requires every segment of a batch to share one size (only the last
    may be shorter), and the kernel caps a single ``sendmsg`` at
    ``_GSO_MAX_SEGMENTS`` segments. Returns ``(size, batch)`` tuples.
    """
    grouped: list[tuple[int, list[bytes]]] = []
    for dgram in datagrams:
        if grouped:
            last_size, last_batch = grouped[-1]
            if len(dgram) == last_size and len(last_batch) < _GSO_MAX_SEGMENTS:
                last_batch.append(dgram)
                continue
        grouped.append((len(dgram), [dgram]))
    return grouped
def sync_recv_gro(
    sock: socket.socket, bufsize: int, gro_segment_size: int = 1280
) -> bytes | list[bytes]:
    """Blocking ``recvmsg`` that understands GRO coalescing.

    Returns one datagram as ``bytes``, or ``list[bytes]`` when the kernel
    coalesced several datagrams into a single buffer.
    """
    data, ancdata, _flags, _addr = sock.recvmsg(
        bufsize, socket.CMSG_SPACE(_UINT16.size)
    )
    if not data:
        return b""
    # Use the caller-provided segment size unless the kernel reported the
    # actual one in the UDP_GRO control message.
    segment_size = gro_segment_size
    for level, ctype, cdata in ancdata:
        if level == socket.SOL_UDP and ctype == UDP_LINUX_GRO:
            (segment_size,) = _UINT16.unpack(cdata[:2])
            break
    if len(data) <= segment_size:
        return data
    return _split_gro_buffer(data, segment_size)
def sync_sendmsg_gso(sock: socket.socket, datagrams: list[bytes]) -> None:
    """Blocking batch send via UDP GSO; lone datagrams go out individually."""
    for segment_size, batch in _group_by_segment_size(datagrams):
        if len(batch) == 1:
            # GSO only pays off with more than one segment.
            sock.sendall(batch[0])
        else:
            sock.sendmsg(
                [b"".join(batch)],
                [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
            )
class _OptimizedDatagramTransport(asyncio.DatagramTransport):
    """Datagram transport batching UDP I/O with Linux GRO (receive) and GSO (send).

    Receive path: one ``recvmsg`` may deliver several kernel-coalesced
    datagrams plus a ``UDP_GRO`` cmsg carrying the segment size used to
    split them back apart. Send path: ``sendto_many`` packs runs of
    equal-sized datagrams into a single ``sendmsg`` with a ``UDP_SEGMENT``
    cmsg. Write flow control mirrors asyncio transports: a deque-backed
    send queue with high/low watermarks driving ``pause_writing`` /
    ``resume_writing`` on the protocol.
    """

    __slots__ = (
        "_loop",
        "_sock",
        "_protocol",
        "_address",
        "_gro_enabled",
        "_gso_enabled",
        "_gro_segment_size",
        "_recv_buf_size",
        "_closing",
        "_closed_fut",
        "_extra",
        "_paused",
        "_write_ready",
        "_send_queue",
        "_buffer_size",
        "_protocol_paused",
    )

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        sock: socket.socket,
        protocol: asyncio.DatagramProtocol,
        address: tuple[str, int] | None,
        gro_enabled: bool,
        gso_enabled: bool,
        gro_segment_size: int,
    ) -> None:
        super().__init__()
        self._loop = loop
        self._sock = sock
        self._protocol = protocol
        # Connected-peer address; None when the socket is unconnected.
        self._address = address
        self._gro_enabled = gro_enabled
        self._gso_enabled = gso_enabled
        self._gro_segment_size = gro_segment_size
        self._closing = False
        self._closed_fut: asyncio.Future[None] = loop.create_future()
        self._paused = False
        # True while the socket accepts writes without EAGAIN.
        self._write_ready = True
        # Write buffer state
        self._send_queue: deque[tuple[bytes, tuple[str, int] | None]] = (
            collections.deque()
        )
        self._buffer_size = 0
        self._protocol_paused = False
        # With GRO a single recvmsg can return a whole coalesced buffer, so
        # use the maximum UDP payload size; otherwise one segment suffices.
        self._recv_buf_size = _DEFAULT_GRO_BUF if gro_enabled else gro_segment_size
        self._extra = {
            "peername": address,
            "socket": sock,
            "sockname": sock.getsockname(),
        }

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Return transport metadata ("peername", "socket", "sockname")."""
        return self._extra.get(name, default)

    def is_closing(self) -> bool:
        return self._closing

    def close(self) -> None:
        """Stop reading and finish shutdown once the write queue drains."""
        if self._closing:
            return
        self._closing = True
        self._loop.remove_reader(self._sock.fileno())
        # Drain the write queue gracefully in the background: when data is
        # still queued, _on_write_ready() completes the shutdown after the
        # last datagram is flushed.
        if not self._send_queue:
            self._loop.call_soon(self._call_connection_lost, None)

    def abort(self) -> None:
        """Close immediately, discarding any queued datagrams."""
        self._closing = True
        self._call_connection_lost(None)

    def _call_connection_lost(self, exc: Exception | None) -> None:
        # Tear down loop callbacks, notify the protocol, close the socket
        # and resolve the "closed" future (exactly once).
        try:
            self._loop.remove_reader(self._sock.fileno())
            self._loop.remove_writer(self._sock.fileno())
        except Exception:
            pass
        try:
            self._protocol.connection_lost(exc)
        finally:
            self._sock.close()
            if not self._closed_fut.done():
                self._closed_fut.set_result(None)

    def sendto(self, data: bytes, addr: tuple[str, int] | None = None) -> None:  # type: ignore[override]
        """Send one datagram, queueing it when the socket buffer is full."""
        if self._closing:
            raise OSError("Transport is closing")
        target = addr or self._address
        if not self._write_ready:
            # A flush is already pending — keep ordering by queueing.
            self._queue_write(data, target)
            return
        try:
            if target is not None:
                self._sock.sendto(data, target)
            else:
                self._sock.send(data)
        except BlockingIOError:
            self._write_ready = False
            self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
            self._queue_write(data, target)
        except OSError as exc:
            # Report the failure; the datagram is dropped.
            self._protocol.error_received(exc)

    def sendto_many(self, datagrams: list[bytes]) -> None:
        """Send multiple datagrams, using GSO when available.
        Falls back to individual ``sendto`` calls when GSO is not
        supported or the socket write buffer is full."""
        if self._closing:
            raise OSError("Transport is closing")
        if not self._write_ready:
            target = self._address
            for dgram in datagrams:
                self._queue_write(dgram, target)
            return
        if self._gso_enabled:
            self._send_linux_gso(datagrams)
        else:
            for dgram in datagrams:
                self.sendto(dgram)

    def _send_linux_gso(self, datagrams: list[bytes]) -> None:
        # One sendmsg per run of equal-sized datagrams (kernel GSO rule);
        # on the first EAGAIN the remainder of the run is queued.
        for segment_size, group in _group_by_segment_size(datagrams):
            if len(group) == 1:
                # Single datagram — plain send (GSO needs >1 segment)
                try:
                    self._sock.send(group[0])
                except BlockingIOError:
                    self._write_ready = False
                    self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                    self._queue_write(group[0], self._address)
                    return
                except OSError as exc:
                    self._protocol.error_received(exc)
                continue
            buf = b"".join(group)
            try:
                self._sock.sendmsg(
                    [buf],
                    [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
                )
            except BlockingIOError:
                self._write_ready = False
                self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                # Queue individual datagrams as fallback
                for dgram in group:
                    self._queue_write(dgram, self._address)
                return
            except OSError as exc:
                self._protocol.error_received(exc)

    def _queue_write(self, data: bytes, addr: tuple[str, int] | None) -> None:
        # Buffer a datagram for a later flush and apply back-pressure.
        self._send_queue.append((data, addr))
        self._buffer_size += len(data)
        self._maybe_pause_protocol()

    def _maybe_pause_protocol(self) -> None:
        # AttributeError tolerated: the protocol may not implement the
        # optional flow-control callbacks.
        if self._buffer_size >= _HIGH_WATERMARK and not self._protocol_paused:
            self._protocol_paused = True
            try:
                self._protocol.pause_writing()
            except AttributeError:
                pass

    def _maybe_resume_protocol(self) -> None:
        if self._protocol_paused and self._buffer_size <= _LOW_WATERMARK:
            self._protocol_paused = False
            try:
                self._protocol.resume_writing()
            except AttributeError:
                pass

    def _on_write_ready(self) -> None:
        # Loop writer callback: flush queued datagrams until EAGAIN or empty.
        while self._send_queue:
            data, addr = self._send_queue[0]
            try:
                if addr is not None:
                    self._sock.sendto(data, addr)
                else:
                    self._sock.send(data)
            except BlockingIOError:
                # Still congested — stay registered and retry later.
                return
            except OSError as exc:
                # Report, then fall through to pop so the queue keeps
                # making progress (the datagram is dropped).
                self._protocol.error_received(exc)
            self._send_queue.popleft()
            self._buffer_size -= len(data)
            self._maybe_resume_protocol()
        self._write_ready = True
        self._loop.remove_writer(self._sock.fileno())
        if self._closing:
            # close() was called while data was still queued — finish now.
            self._call_connection_lost(None)

    def pause_reading(self) -> None:
        """Temporarily stop delivering incoming datagrams to the protocol."""
        if not self._paused:
            self._paused = True
            self._loop.remove_reader(self._sock.fileno())

    def resume_reading(self) -> None:
        """Resume delivery after ``pause_reading()``."""
        if self._paused:
            self._paused = False
            self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _start(self) -> None:
        # Wire the protocol and begin reading; invoked by create_udp_endpoint.
        self._loop.call_soon(self._protocol.connection_made, self)
        self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _on_readable(self) -> None:
        if self._closing:
            return
        self._recv_linux_gro()

    def _recv_linux_gro(self) -> None:
        # Drain the socket: each recvmsg may carry several GRO-coalesced
        # datagrams plus a UDP_GRO cmsg holding the actual segment size.
        ancbufsize = socket.CMSG_SPACE(_UINT16.size)
        while True:
            try:
                data, ancdata, _flags, addr = self._sock.recvmsg(
                    self._recv_buf_size, ancbufsize
                )
            except BlockingIOError:
                return
            except OSError as exc:
                self._protocol.error_received(exc)
                return
            if not data:
                return
            segment_size = self._gro_segment_size
            for cmsg_level, cmsg_type, cmsg_data in ancdata:
                if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
                    (segment_size,) = _UINT16.unpack(cmsg_data[:2])
                    break
            if len(data) <= segment_size:
                self._protocol.datagram_received(data, addr)
            else:
                # Coalesced buffer: hand all segments over at once through
                # the non-standard ``datagrams_received`` protocol hook.
                segments = _split_gro_buffer(data, segment_size)
                self._protocol.datagrams_received(segments, addr)  # type: ignore[attr-defined]
async def create_udp_endpoint(
    loop: asyncio.AbstractEventLoop,
    protocol_factory: Callable[[], asyncio.DatagramProtocol],
    *,
    local_addr: tuple[str, int] | None = None,
    remote_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    reuse_port: bool = False,
    gro_segment_size: int = 1280,
    sock: socket.socket | None = None,
) -> tuple[asyncio.DatagramTransport, asyncio.DatagramProtocol]:
    """Create a UDP endpoint, preferring the GRO/GSO-optimized transport.

    Mirrors ``loop.create_datagram_endpoint`` and falls back to it when the
    socket supports neither GRO nor GSO. When *sock* is given it is used
    as-is (the caller must have bound/connected it and, for GRO, enabled
    the option via ``setsockopt``); otherwise a non-blocking socket is
    created from *local_addr*/*remote_addr*.
    """
    if sock is not None:
        # Caller provided a pre-connected socket — skip creation/bind/connect.
        try:
            connected_addr = sock.getpeername()
        except OSError:
            # Unconnected socket: sendto() must carry an explicit address.
            connected_addr = None
    else:
        # 1. Resolve Addresses
        if family == socket.AF_UNSPEC:
            target_addr = local_addr or remote_addr
            if target_addr:
                infos = await loop.getaddrinfo(
                    target_addr[0], target_addr[1], type=socket.SOCK_DGRAM
                )
                family = infos[0][0]
            else:
                family = socket.AF_INET
        # 2. Create Socket
        sock = socket.socket(family, socket.SOCK_DGRAM)
        sock.setblocking(False)
        if reuse_port:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if local_addr:
            sock.bind(local_addr)
        connected_addr = None
        if remote_addr:
            await loop.sock_connect(sock, remote_addr)
            connected_addr = remote_addr
    # 3. Determine capabilities — the caller is responsible for
    # enabling GRO via setsockopt before handing us the socket.
    gro_enabled = _sock_has_gro(sock)
    gso_enabled = _sock_has_gso(sock)
    if not gro_enabled and not gso_enabled:
        # Nothing to optimize — delegate to the stock asyncio transport.
        return await loop.create_datagram_endpoint(
            lambda: protocol_factory(), sock=sock
        )
    # 4. Wire up optimized transport
    protocol = protocol_factory()
    transport = _OptimizedDatagramTransport(
        loop=loop,
        sock=sock,
        protocol=protocol,
        address=connected_addr,
        gro_enabled=gro_enabled,
        gso_enabled=gso_enabled,
        gro_segment_size=gro_segment_size,
    )
    transport._start()
    return transport, protocol
class DatagramReader:
    """API-compatible with ``asyncio.StreamReader`` (duck-typed) so that
    ``AsyncSocket`` can assign an instance to ``self._reader`` and the
    existing ``recv()`` code works unchanged.
    When GRO delivers multiple coalesced segments in a single syscall,
    ``feed_datagrams()`` stores them as a single ``list[bytes]`` entry.
    ``read()`` then returns that list directly so the caller can feed
    all segments to the QUIC state-machine in one pass before probing —
    avoiding the per-datagram recv→feed→probe round-trip overhead."""

    def __init__(self) -> None:
        # Each entry is either one datagram (bytes) or one GRO batch
        # (list[bytes]) as delivered by a single syscall.
        self._buffer: deque[bytes | list[bytes]] = collections.deque()
        # Future awaited by a pending read(); woken on feed/EOF/error.
        self._waiter: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._eof = False

    def feed_datagram(self, data: bytes, addr: Any) -> None:
        """Feed a single (non-coalesced) datagram."""
        self._buffer.append(data)
        self._wake_waiter()

    def feed_datagrams(self, data: list[bytes], addr: Any) -> None:
        """Feed a batch of coalesced datagrams as a single entry."""
        self._buffer.append(data)
        self._wake_waiter()

    def set_exception(self, exc: BaseException) -> None:
        # Stored and re-raised by the next read() once the buffer drains.
        self._exception = exc
        self._wake_waiter()

    def connection_lost(self, exc: BaseException | None) -> None:
        # Transport closed: mark EOF (and remember the error, if any).
        self._eof = True
        if exc is not None:
            self._exception = exc
        self._wake_waiter()

    def _wake_waiter(self) -> None:
        waiter = self._waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    async def read(self, n: int = -1) -> bytes | list[bytes]:
        """Return the next entry from the buffer.
        * ``bytes`` — a single datagram (non-coalesced).
        * ``list[bytes]`` — a batch of coalesced datagrams from one
          GRO syscall.
        * ``b""`` — EOF.

        ``n`` is accepted for StreamReader API compatibility but ignored:
        datagrams are always returned whole.
        """
        if self._buffer:
            return self._buffer.popleft()
        if self._exception is not None:
            raise self._exception
        if self._eof:
            return b""
        # NOTE(review): a single pending read() is assumed — a concurrent
        # second call would overwrite ``_waiter``.
        self._waiter = asyncio.get_running_loop().create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None
        if self._buffer:
            return self._buffer.popleft()
        if self._exception is not None:
            raise self._exception
        return b""
class DatagramWriter:
    """API-compatible with ``asyncio.StreamWriter`` (duck-typed) so that
    ``AsyncSocket`` can assign an instance to ``self._writer`` and the
    existing ``sendall()``, ``close()``, ``wait_for_close()`` code works
    unchanged."""

    def __init__(
        self,
        transport: asyncio.DatagramTransport,
    ) -> None:
        self._transport = transport
        # Destination for connected sockets; None → transport's default.
        self._address: tuple[str, int] | None = transport.get_extra_info("peername")
        # Set by the bridge protocol when the transport reports closure.
        self._closed_event = asyncio.Event()
        # Back-pressure state driven by pause_writing()/resume_writing().
        self._paused = False
        self._drain_waiter: asyncio.Future[None] | None = None

    @property
    def transport(self) -> asyncio.DatagramTransport:
        return self._transport

    def write(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Send one datagram, or a batch when *data* is a list.

        Writes are silently dropped once the transport is closing.
        """
        if self._transport.is_closing():
            return
        if isinstance(data, list):
            if hasattr(self._transport, "sendto_many"):
                self._transport.sendto_many(data)
            else:
                # Plain asyncio transport — send individually
                for dgram in data:
                    self._transport.sendto(dgram, self._address)
        else:
            self._transport.sendto(bytes(data), self._address)

    async def drain(self) -> None:
        """Wait until the transport resumes writing (no-op when not paused)."""
        if not self._paused:
            return
        self._drain_waiter = asyncio.get_running_loop().create_future()
        try:
            await self._drain_waiter
        finally:
            self._drain_waiter = None

    def close(self) -> None:
        """Request a graceful transport shutdown."""
        self._transport.close()

    async def wait_closed(self) -> None:
        """Block until the transport has reported connection_lost."""
        await self._closed_event.wait()

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        return self._transport.get_extra_info(name, default)

    def _pause_writing(self) -> None:
        # Invoked by the bridge protocol's pause_writing callback.
        self._paused = True

    def _resume_writing(self) -> None:
        # Invoked by the bridge protocol's resume_writing callback; wakes
        # any coroutine blocked in drain().
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)
class _DatagramBridgeProtocol(asyncio.DatagramProtocol):
"""Bridges ``asyncio.DatagramProtocol`` callbacks to
``DatagramReader`` / ``DatagramWriter``."""
def __init__(self, reader: DatagramReader) -> None:
self._reader = reader
self._writer: DatagramWriter | None = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
pass # transport is already wired via DatagramWriter
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
self._reader.feed_datagram(data, addr)
def datagrams_received(self, data: list[bytes], addr: tuple[str, int]) -> None:
self._reader.feed_datagrams(data, addr)
def error_received(self, exc: Exception) -> None:
self._reader.set_exception(exc)
def connection_lost(self, exc: BaseException | None) -> None:
self._reader.connection_lost(exc)
if self._writer is not None:
self._writer._closed_event.set()
def pause_writing(self) -> None:
if self._writer is not None:
self._writer._pause_writing()
def resume_writing(self) -> None:
if self._writer is not None:
self._writer._resume_writing()
async def open_dgram_connection(
    remote_addr: tuple[str, int] | None = None,
    *,
    local_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    sock: socket.socket | None = None,
    gro_segment_size: int = 1280,
) -> tuple[DatagramReader, DatagramWriter]:
    """Open a UDP endpoint and expose it as a (reader, writer) pair,
    mirroring ``asyncio.open_connection`` for datagram traffic."""
    loop = asyncio.get_running_loop()
    reader = DatagramReader()
    protocol = _DatagramBridgeProtocol(reader)
    transport, _ = await create_udp_endpoint(
        loop,
        lambda: protocol,
        local_addr=local_addr,
        remote_addr=remote_addr,
        family=family,
        gro_segment_size=gro_segment_size,
        sock=sock,
    )
    writer = DatagramWriter(transport)
    # Late-bind the writer so the protocol can relay flow-control and
    # close notifications to it.
    protocol._writer = writer
    return reader, writer

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
import enum
from asyncio import CancelledError, events, tasks
from types import TracebackType
__all__ = (
"Timeout",
"timeout",
)
class _State(enum.Enum):
    """Lifecycle of a Timeout context; values surface in repr()/error text."""

    CREATED = "created"
    ENTERED = "active"
    EXPIRING = "expiring"
    EXPIRED = "expired"
    EXITED = "finished"
class Timeout:
    """Asynchronous context manager for cancelling overdue coroutines.
    Use `timeout()` or `timeout_at()` rather than instantiating this class directly.
    """

    def __init__(self, when: float | None) -> None:
        """Schedule a timeout that will trigger at a given loop time.
        - If `when` is `None`, the timeout will never trigger.
        - If `when < loop.time()`, the timeout will trigger on the next
          iteration of the event loop.
        """
        self._state = _State.CREATED
        self._timeout_handler: events.TimerHandle | events.Handle | None = None
        # The task to cancel when the deadline fires; set on __aenter__.
        self._task: tasks.Task | None = None  # type: ignore[type-arg]
        self._when = when

    def when(self) -> float | None:
        """Return the current deadline."""
        return self._when

    def reschedule(self, when: float | None) -> None:
        """Reschedule the timeout.

        Only valid while the context is active (entered, not yet expiring).
        """
        if self._state is not _State.ENTERED:
            if self._state is _State.CREATED:
                raise RuntimeError("Timeout has not been entered")
            raise RuntimeError(
                f"Cannot change state of {self._state.value} Timeout",
            )
        self._when = when
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
        if when is None:
            self._timeout_handler = None
        else:
            loop = events.get_running_loop()
            if when <= loop.time():
                # Deadline already passed: fire on the next loop iteration.
                self._timeout_handler = loop.call_soon(self._on_timeout)
            else:
                self._timeout_handler = loop.call_at(when, self._on_timeout)

    def expired(self) -> bool:
        """Is timeout expired during execution?"""
        return self._state in (_State.EXPIRING, _State.EXPIRED)

    def __repr__(self) -> str:
        info = [""]
        if self._state is _State.ENTERED:
            when = round(self._when, 3) if self._when is not None else None
            info.append(f"when={when}")
        info_str = " ".join(info)
        return f"<Timeout [{self._state.value}]{info_str}>"

    async def __aenter__(self) -> "Timeout":
        if self._state is not _State.CREATED:
            raise RuntimeError("Timeout has already been entered")
        task = tasks.current_task()
        if task is None:
            raise RuntimeError("Timeout should be used inside a task")
        self._state = _State.ENTERED
        self._task = task
        self.reschedule(self._when)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        assert self._state in (_State.ENTERED, _State.EXPIRING)
        assert self._task is not None
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
            self._timeout_handler = None
        if self._state is _State.EXPIRING:
            self._state = _State.EXPIRED
            # Our deadline fired and cancelled the task; translate the
            # resulting CancelledError into TimeoutError for the caller.
            if exc_type is CancelledError:
                # Since there are no new cancel requests, we're
                # handling this.
                raise TimeoutError from exc_val
        elif self._state is _State.ENTERED:
            self._state = _State.EXITED
        return None

    def _on_timeout(self) -> None:
        # Deadline reached: cancel the owning task; __aexit__ later turns
        # the resulting CancelledError into TimeoutError.
        assert self._state is _State.ENTERED
        assert self._task is not None
        self._task.cancel()
        self._state = _State.EXPIRING
        # drop the reference early
        self._timeout_handler = None
def timeout(delay: float | None) -> Timeout:
    """Return an async context manager enforcing *delay* seconds.

    Useful when ``asyncio.wait_for`` is not suitable, e.g. around a block
    of code::

        async with timeout(10):  # 10 seconds timeout
            await long_running_task()

    ``delay`` is a value in seconds, or ``None`` to disable the timeout.
    When the deadline hits, the running task is cancelled and the top-most
    affected ``timeout()`` context manager converts the resulting
    ``CancelledError`` into ``TimeoutError``.
    """
    running_loop = events.get_running_loop()
    deadline = None if delay is None else running_loop.time() + delay
    return Timeout(deadline)

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from .protocol import ExtensionFromHTTP
from .raw import RawExtensionFromHTTP
from .sse import ServerSideEventExtensionFromHTTP
try:
from .ws import WebSocketExtensionFromHTTP, WebSocketExtensionFromMultiplexedHTTP
except ImportError:
WebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
WebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
from typing import TypeVar
T = TypeVar("T")
def recursive_subclasses(cls: type[T]) -> list[type[T]]:
    """Return every direct and indirect subclass of *cls*.

    The walk is depth-first and each class is listed exactly once, in
    first-seen order — the original naive walk listed a class once per
    inheritance path, so diamond hierarchies produced duplicates (and
    callers such as ``load_extension`` probed the same class repeatedly).
    """
    # dict preserves insertion order and gives O(1) dedup membership tests.
    seen: dict[type[T], None] = {}
    for subclass in cls.__subclasses__():
        if subclass not in seen:
            seen[subclass] = None
        for nested in recursive_subclasses(subclass):
            if nested not in seen:
                seen[nested] = None
    return list(seen)
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[ExtensionFromHTTP]:
    """Locate the extension class that supports *scheme* (optionally a
    specific *implementation*); a None scheme selects raw I/O."""
    if scheme is None:
        return RawExtensionFromHTTP
    wanted_scheme = scheme.lower()
    wanted_implementation = implementation.lower() if implementation else implementation
    for candidate in recursive_subclasses(ExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if (
            wanted_implementation is not None
            and candidate.implementation() != wanted_implementation
        ):
            continue
        return candidate
    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
__all__ = (
"ExtensionFromHTTP",
"RawExtensionFromHTTP",
"WebSocketExtensionFromHTTP",
"WebSocketExtensionFromMultiplexedHTTP",
"ServerSideEventExtensionFromHTTP",
"load_extension",
)

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from .protocol import AsyncExtensionFromHTTP
from .raw import AsyncRawExtensionFromHTTP
from .sse import AsyncServerSideEventExtensionFromHTTP
try:
from .ws import (
AsyncWebSocketExtensionFromHTTP,
AsyncWebSocketExtensionFromMultiplexedHTTP,
)
except ImportError:
AsyncWebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
AsyncWebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
from .. import recursive_subclasses
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[AsyncExtensionFromHTTP]:
    """Locate the async extension class that supports *scheme* (optionally
    a specific *implementation*); a None scheme selects raw I/O."""
    if scheme is None:
        return AsyncRawExtensionFromHTTP
    wanted_scheme = scheme.lower()
    wanted_implementation = implementation.lower() if implementation else implementation
    for candidate in recursive_subclasses(AsyncExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if (
            wanted_implementation is not None
            and candidate.implementation() != wanted_implementation
        ):
            continue
        return candidate
    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
__all__ = (
"AsyncExtensionFromHTTP",
"AsyncRawExtensionFromHTTP",
"AsyncWebSocketExtensionFromHTTP",
"AsyncWebSocketExtensionFromMultiplexedHTTP",
"AsyncServerSideEventExtensionFromHTTP",
"load_extension",
)

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
import typing
from abc import ABCMeta
from contextlib import asynccontextmanager
from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from ....backend import HttpVersion
from ....backend._async._base import AsyncDirectStreamAccess
from ....util._async.traffic_police import AsyncTrafficPolice
from ....exceptions import (
BaseSSLError,
ProtocolError,
ReadTimeoutError,
SSLError,
MustRedialError,
)
class AsyncExtensionFromHTTP(metaclass=ABCMeta):
    """Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
    This will considerably ease downstream integration."""

    def __init__(self) -> None:
        # All three are populated by start(); None means "not started yet".
        self._dsa: AsyncDirectStreamAccess | None = None
        self._response: AsyncHTTPResponse | None = None
        self._police_officer: AsyncTrafficPolice | None = None  # type: ignore[type-arg]

    @asynccontextmanager
    async def _read_error_catcher(self) -> typing.AsyncGenerator[None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # A read timeout is recoverable — don't tear down the
                # connection, just surface it as ReadTimeoutError.
                clean_exit = True
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                clean_exit = True  # ws algorithms based on timeouts can expect this without being harmful!
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    await self.close()

    @asynccontextmanager
    async def _write_error_catcher(self) -> typing.AsyncGenerator[None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                raise SSLError(e) from e
            except OSError as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    await self.close()

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        """Extra keyword arguments the extension wants forwarded to urlopen
        (none by default; subclasses may override)."""
        return {}

    async def start(self, response: AsyncHTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol."""
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise OSError("The HTTP extension is closed or uninitialized")
        # Keep direct stream access, the traffic police (connection pool
        # arbiter) and the originating response for later I/O.
        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response

    @property
    def closed(self) -> bool:
        """True once the extension released its direct stream access."""
        return self._dsa is None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        raise NotImplementedError

    @staticmethod
    def implementation() -> str:
        """Identifier compared against the implementation requested at load time."""
        raise NotImplementedError

    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        raise NotImplementedError

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        raise NotImplementedError

    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError

    async def on_payload(
        self, callback: typing.Callable[[str | bytes | None], typing.Awaitable[None]]
    ) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from ....backend import HttpVersion
from .protocol import AsyncExtensionFromHTTP
class AsyncRawExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Usable over HTTP/1.1, HTTP/2 and HTTP/3 parents."""
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        return {}

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            await self._dsa.close()
            self._dsa = None
        if self._response is not None:
            await self._response.close()
            self._response = None
        self._police_officer = None

    @staticmethod
    def implementation() -> str:
        return "raw"

    @staticmethod
    def supported_schemes() -> set[str]:
        # Raw mode is the fallback when no scheme is requested, so it does
        # not register any scheme of its own.
        return set()

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return scheme

    async def next_payload(self) -> bytes | None:
        """Read the next available chunk of raw bytes from the stream.

        The end-of-transmission flag returned alongside the data is ignored
        here; only the payload is surfaced.
        """
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        # Borrow the connection so no other task uses it mid-read.
        async with self._police_officer.borrow(self._response):
            async with self._read_error_catcher():
                data, eot, _ = await self._dsa.recv_extended(None)
                return data

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote (str is UTF-8 encoded first)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        if isinstance(buf, str):
            buf = buf.encode()
        async with self._police_officer.borrow(self._response):
            async with self._write_error_catcher():
                await self._dsa.sendall(buf)

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from ....backend import HttpVersion
from ..sse import ServerSentEvent
from .protocol import AsyncExtensionFromHTTP
class AsyncServerSideEventExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Consume a ``text/event-stream`` response as Server-Sent Events (read-only)."""

    def __init__(self) -> None:
        super().__init__()
        # Last seen event id, re-applied to subsequent events lacking one.
        self._last_event_id: str | None = None
        # Holds an incomplete event (no terminating blank line seen yet).
        self._buffer: str = ""
        self._stream: typing.AsyncGenerator[bytes, None] | None = None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        return "native"

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # SSE must stream; never preload the response body.
        return {"preload_content": False}

    async def close(self) -> None:
        """Close the event stream, aborting the underlying response if possible."""
        if self._stream is not None and self._response is not None:
            await self._stream.aclose()
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                async with self._police_officer.borrow(self._response):
                    await self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None

    @property
    def closed(self) -> bool:
        return self._stream is None

    async def start(self, response: AsyncHTTPResponse) -> None:
        await super().start(response)
        # Unbounded streaming read with content-decoding enabled.
        self._stream = response.stream(-1, decode_content=True)

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Request headers advertising SSE support and disabling caching."""
        return {"accept": "text/event-stream", "cache-control": "no-store"}

    @typing.overload
    async def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...

    @typing.overload
    async def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...

    async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote.

        Returns the parsed ``ServerSentEvent`` (or the raw text when
        *raw* is True), or None once the stream is exhausted.
        """
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        try:
            raw_payload: str = (await self._stream.__anext__()).decode("utf-8")
        except StopAsyncIteration:
            await self._stream.aclose()
            self._stream = None
            return None
        # Prepend any partial event left over from the previous chunk.
        if self._buffer:
            raw_payload = self._buffer + raw_payload
            self._buffer = ""
        kwargs: dict[str, typing.Any] = {}
        eot = False
        # Parse "field: value" lines; a blank line terminates the event.
        for line in raw_payload.splitlines():
            if not line:
                eot = True
                break
            key, _, value = line.partition(":")
            if key not in {"event", "data", "retry", "id"}:
                continue
            # A single space right after the colon is not part of the value.
            if value.startswith(" "):
                value = value[1:]
            if key == "id":
                # Ids containing a NUL character are ignored.
                if "\u0000" in value:
                    continue
            if key == "retry":
                # retry must be an integer (milliseconds); drop it otherwise.
                try:
                    value = int(value)  # type: ignore[assignment]
                except (ValueError, TypeError):
                    continue
            kwargs[key] = value
        if eot is False:
            # Incomplete event: stash what we have and read the next chunk.
            self._buffer = raw_payload
            return await self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]
        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id
        event = ServerSentEvent(**kwargs)
        if event.id:
            self._last_event_id = event.id
        if raw is True:
            return raw_payload
        return event

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"sse", "psse"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # "sse" maps to https; "psse" to plain http.
        return {"sse": "https", "psse": "http"}[scheme]

View File

@@ -0,0 +1,238 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...._async.response import AsyncHTTPResponse
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
AcceptConnection,
BytesMessage,
CloseConnection,
Ping,
Pong,
Request,
TextMessage,
)
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import ProtocolError as WebSocketProtocolError
from ....backend import HttpVersion
from ....exceptions import ProtocolError
from .protocol import AsyncExtensionFromHTTP
class AsyncWebSocketExtensionFromHTTP(AsyncExtensionFromHTTP):
    """WebSocket (RFC 6455) extension negotiated through an HTTP/1.1 "101 Switching Protocols" upgrade.

    A wsproto client state-machine is fed with bytes read from the
    connection's direct stream access, and produces the frames to send back.
    This mirrors the synchronous ``WebSocketExtensionFromHTTP``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Cached so headers() stays idempotent: wsproto only allows a single
        # Request event per connection.
        self._request_headers: dict[str, str] | None = None
        # Set once the remote sent CloseConnection; no frames may follow.
        self._remote_shutdown: bool = False

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11}

    @staticmethod
    def implementation() -> str:
        return "wsproto"

    async def start(self, response: AsyncHTTPResponse) -> None:
        """Feed a reconstructed 101 response into the wsproto handshake."""
        await super().start(response)
        # wsproto wants to parse the raw HTTP response itself; rebuild a
        # minimal equivalent from the headers urllib3 already consumed.
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
        fake_http_response += b"Sec-Websocket-Accept: "
        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )
        fake_http_response += accept_token.encode() + b"\r\n"
        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )
        fake_http_response += b"\r\n"
        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!
        event = next(self._protocol.events())
        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers
        try:
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!
        # Drop the request line and the trailing CRLF pair, keep header lines.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}
        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v
        if http_version != HttpVersion.h11:
            # RFC 8441 extended CONNECT: Upgrade/Connection are h11-only.
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"
        self._request_headers = request_headers
        return request_headers

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                async with self._police_officer.borrow(self._response):
                    if self._remote_shutdown is False:
                        try:
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            # Already closing/errored: nothing left to flush.
                            pass
                        else:
                            async with self._write_error_catcher():
                                await self._dsa.sendall(data_to_send)
                    await self._dsa.close()
                self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                self._police_officer.forget(self._response)
            else:
                await self._response.close()
            self._response = None
        self._police_officer = None

    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote.

        Returns ``None`` once the remote closed the WebSocket.
        """
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            # we may have pending event to unpack!
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    await self.close()
                    return None
                elif isinstance(event, Ping):
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        await self.close()
                        raise ProtocolError from e
                    async with self._write_error_catcher():
                        await self._dsa.sendall(data_to_send)
            while True:
                async with self._read_error_catcher():
                    data, eot, _ = await self._dsa.recv_extended(None)
                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    # Fixed: mirror the sync implementation — tear the
                    # extension down before surfacing the protocol violation.
                    await self.close()
                    raise ProtocolError from e
                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        await self.close()
                        return None
                    elif isinstance(event, Ping):
                        # Fixed: guard the pong generation like the sync
                        # implementation does; a failure here is fatal.
                        try:
                            data_to_send = self._protocol.send(event.response())
                        except WebSocketProtocolError as e:
                            await self.close()
                            raise ProtocolError from e
                        async with self._write_error_catcher():
                            await self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        # Nothing to surface for a pong; keep reading.
                        continue

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote (TextMessage for str, BytesMessage for bytes)."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            try:
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                # Fixed: mirror the sync implementation — close before raising.
                await self.close()
                raise ProtocolError from e
            async with self._write_error_catcher():
                await self._dsa.sendall(data_to_send)

    async def ping(self) -> None:
        """Send an unsolicited Ping frame to the remote peer."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        async with self._police_officer.borrow(self._response):
            # Fixed: mirror the sync implementation — never emit frames once
            # the remote announced its shutdown, and close on protocol error.
            if self._remote_shutdown is False:
                try:
                    data_to_send: bytes = self._protocol.send(Ping())
                except WebSocketProtocolError as e:
                    await self.close()
                    raise ProtocolError from e
                async with self._write_error_catcher():
                    await self._dsa.sendall(data_to_send)

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"ws": "http", "wss": "https"}[scheme]
class AsyncWebSocketExtensionFromMultiplexedHTTP(AsyncWebSocketExtensionFromHTTP):
    """
    Plugin that supports doing WebSocket over HTTP/2 and HTTP/3.

    This implements RFC 8441 (bootstrapping WebSockets with HTTP/2). Beware
    that this isn't actually supported by many servers across the internet.
    """
    @staticmethod
    def implementation() -> str:
        return "rfc8441"
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # NOTE(review): h11 is also listed — presumably so negotiation can
        # fall back to the classic RFC 6455 upgrade path; confirm.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
import typing
from abc import ABCMeta
from contextlib import contextmanager
from socket import timeout as SocketTimeout
if typing.TYPE_CHECKING:
from ...backend import HttpVersion
from ...backend._base import DirectStreamAccess
from ...response import HTTPResponse
from ...util.traffic_police import TrafficPolice
from ...exceptions import (
BaseSSLError,
ProtocolError,
ReadTimeoutError,
SSLError,
MustRedialError,
)
class ExtensionFromHTTP(metaclass=ABCMeta):
    """Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
    This will considerably ease downstream integration."""
    def __init__(self) -> None:
        # Raw two-way access to the upgraded stream, set by start().
        self._dsa: DirectStreamAccess | None = None
        # The response that carried the upgrade; kept for pool bookkeeping.
        self._response: HTTPResponse | None = None
        # Pool-level coordinator guarding concurrent access to the stream.
        self._police_officer: TrafficPolice | None = None  # type: ignore[type-arg]
    @contextmanager
    def _read_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # A read timeout is recoverable: mark the exit clean so the
                # connection is kept (see the finally below).
                clean_exit = True
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                clean_exit = True  # ws algorithms based on timeouts can expect this without being harmful!
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()
    @contextmanager
    def _write_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.
        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False
        try:
            try:
                yield
            except SocketTimeout as e:
                # Unlike the read catcher, a write timeout is treated as
                # fatal here: clean_exit stays False, so the connection is
                # torn down in the finally block.
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]
            except BaseSSLError as e:
                raise SSLError(e) from e
            except OSError as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e
            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()
    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        """Return prerequisites. Must be passed as additional parameters to urlopen."""
        return {}
    def start(self, response: HTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol."""
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise RuntimeError(
                "Attempt to start an HTTP extension without direct I/O access to the stream"
            )
        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response
    @property
    def closed(self) -> bool:
        # No direct stream access means the extension was closed (or never
        # started).
        return self._dsa is None
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        raise NotImplementedError
    @staticmethod
    def implementation() -> str:
        """Name of the underlying implementation (e.g. library) backing this extension."""
        raise NotImplementedError
    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        raise NotImplementedError
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        raise NotImplementedError
    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError
    def on_payload(self, callback: typing.Callable[[str | bytes | None], None]) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from ...backend import HttpVersion
from .protocol import ExtensionFromHTTP
class RawExtensionFromHTTP(ExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""
    @staticmethod
    def implementation() -> str:
        return "raw"
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
    @staticmethod
    def supported_schemes() -> set[str]:
        # No dedicated URL scheme: raw passthrough is opted into explicitly.
        return set()
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Identity mapping — the caller already provides an HTTP scheme.
        return scheme
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """No extra request headers are needed ahead of the 101 response."""
        return {}
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            with self._write_error_catcher():
                self._dsa.close()
            self._dsa = None
        if self._response is not None:
            self._response.close()
            self._response = None
        self._police_officer = None
    def next_payload(self) -> bytes | None:
        """Read the next chunk straight off the upgraded stream."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            with self._read_error_catcher():
                received, _, _ = self._dsa.recv_extended(None)
            return received
    def send_payload(self, buf: str | bytes) -> None:
        """Write the given buffer straight onto the upgraded stream."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        payload = buf.encode() if isinstance(buf, str) else buf
        with self._police_officer.borrow(self._response):
            with self._write_error_catcher():
                self._dsa.sendall(payload)

View File

@@ -0,0 +1,185 @@
from __future__ import annotations
import json
import typing
from threading import RLock
if typing.TYPE_CHECKING:
from ...response import HTTPResponse
from ...backend import HttpVersion
from .protocol import ExtensionFromHTTP
class ServerSentEvent:
def __init__(
self,
event: str | None = None,
data: str | None = None,
id: str | None = None,
retry: int | None = None,
) -> None:
if not event:
event = "message"
if data is None:
data = ""
if id is None:
id = ""
self._event = event
self._data = data
self._id = id
self._retry = retry
@property
def event(self) -> str:
return self._event
@property
def data(self) -> str:
return self._data
@property
def id(self) -> str:
return self._id
@property
def retry(self) -> int | None:
return self._retry
def json(self) -> typing.Any:
return json.loads(self.data)
def __repr__(self) -> str:
pieces = [f"event={self.event!r}"]
if self.data != "":
pieces.append(f"data={self.data!r}")
if self.id != "":
pieces.append(f"id={self.id!r}")
if self.retry is not None:
pieces.append(f"retry={self.retry!r}")
return f"ServerSentEvent({', '.join(pieces)})"
class ServerSideEventExtensionFromHTTP(ExtensionFromHTTP):
    """Consume a Server-Sent Events (``text/event-stream``) response."""
    def __init__(self) -> None:
        super().__init__()
        # Last event id observed; inherited by events that omit their own.
        self._last_event_id: str | None = None
        # Partially received event kept until its blank-line terminator arrives.
        self._buffer: str = ""
        # Reentrant on purpose: next_payload() recurses on partial payloads.
        self._lock = RLock()
        self._stream: typing.Generator[bytes, None, None] | None = None
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
    @staticmethod
    def implementation() -> str:
        return "native"
    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # The body must stay unread so the extension can stream events itself.
        return {"preload_content": False}
    @property
    def closed(self) -> bool:
        return self._stream is None
    def close(self) -> None:
        """Stop consuming events and abort the underlying response when possible."""
        if self._stream is not None and self._response is not None:
            self._stream.close()
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                # Abort instead of draining: an SSE body may be endless.
                with self._police_officer.borrow(self._response):
                    self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None
    def start(self, response: HTTPResponse) -> None:
        """Begin streaming the event source from the given response."""
        super().start(response)
        # amt=-1 yields chunks as they arrive, honoring any content-encoding.
        self._stream = response.stream(-1, decode_content=True)
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Request headers advertising SSE consumption."""
        return {"accept": "text/event-stream", "cache-control": "no-store"}
    @typing.overload
    def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...
    @typing.overload
    def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...
    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote.

        Returns the decoded event (or the raw text when ``raw=True``), or
        ``None`` once the remote ended the stream.
        """
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._lock:
            try:
                raw_payload: str = next(self._stream).decode("utf-8")
            except StopIteration:
                # Remote ended the event stream.
                self._stream = None
                return None
            # Prepend any partially received event from the previous read.
            if self._buffer:
                raw_payload = self._buffer + raw_payload
                self._buffer = ""
            kwargs: dict[str, typing.Any] = {}
            eot = False
            for line in raw_payload.splitlines():
                if not line:
                    # A blank line terminates an event (end-of-transmission).
                    eot = True
                    break
                key, _, value = line.partition(":")
                # Only spec-defined field names are kept; this also discards
                # comment lines (leading ":", so key == "").
                if key not in {"event", "data", "retry", "id"}:
                    continue
                # The spec allows a single optional space after the colon.
                if value.startswith(" "):
                    value = value[1:]
                if key == "id":
                    # Ids containing NUL must be ignored per the SSE spec.
                    if "\u0000" in value:
                        continue
                if key == "retry":
                    # retry must parse as an integer (milliseconds).
                    try:
                        value = int(value)  # type: ignore[assignment]
                    except (ValueError, TypeError):
                        continue
                # NOTE(review): repeated "data:" lines overwrite one another
                # instead of being newline-joined as the SSE spec prescribes
                # — confirm this is intended (the async twin behaves the same).
                kwargs[key] = value
            if eot is False:
                # Incomplete event: stash it and recurse until terminated.
                self._buffer = raw_payload
                return self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]
            # Events without an id inherit the last one seen on this stream.
            if "id" not in kwargs and self._last_event_id is not None:
                kwargs["id"] = self._last_event_id
            event = ServerSentEvent(**kwargs)
            if event.id:
                self._last_event_id = event.id
            if raw is True:
                return raw_payload
            return event
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
    @staticmethod
    def supported_schemes() -> set[str]:
        return {"sse", "psse"}
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"sse": "https", "psse": "http"}[scheme]

View File

@@ -0,0 +1,247 @@
from __future__ import annotations
import typing
if typing.TYPE_CHECKING:
from ...response import HTTPResponse
from wsproto import ConnectionType, WSConnection
from wsproto.events import (
AcceptConnection,
BytesMessage,
CloseConnection,
Ping,
Pong,
Request,
TextMessage,
)
from wsproto.extensions import PerMessageDeflate
from wsproto.utilities import ProtocolError as WebSocketProtocolError
from ...backend import HttpVersion
from ...exceptions import ProtocolError
from .protocol import ExtensionFromHTTP
class WebSocketExtensionFromHTTP(ExtensionFromHTTP):
    """WebSocket (RFC 6455) extension negotiated through an HTTP/1.1 "101 Switching Protocols" upgrade."""
    def __init__(self) -> None:
        super().__init__()
        # wsproto client state-machine: fed received bytes, emits frames to send.
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Cached so headers() stays idempotent; wsproto allows one Request only.
        self._request_headers: dict[str, str] | None = None
        # Set once the remote sent CloseConnection; no frames may follow.
        self._remote_shutdown: bool = False
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11}
    @staticmethod
    def implementation() -> str:
        return "wsproto"
    def start(self, response: HTTPResponse) -> None:
        """Feed a reconstructed 101 response into the wsproto handshake."""
        super().start(response)
        # wsproto wants to parse the raw HTTP response itself; rebuild a
        # minimal equivalent from the headers urllib3 already consumed.
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
        fake_http_response += b"Sec-Websocket-Accept: "
        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )
        fake_http_response += accept_token.encode() + b"\r\n"
        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )
        fake_http_response += b"\r\n"
        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen
        event = next(self._protocol.events())
        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers
        try:
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen
        # Drop the request line and the trailing CRLF pair, keep header lines.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}
        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v
        if http_version != HttpVersion.h11:
            # RFC 8441 extended CONNECT: Upgrade/Connection are h11-only.
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"
        self._request_headers = request_headers
        return request_headers
    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                with self._police_officer.borrow(self._response):
                    if self._remote_shutdown is False:
                        try:
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            # Already closing/errored: nothing left to flush.
                            pass
                        else:
                            with self._write_error_catcher():
                                self._dsa.sendall(data_to_send)
                        self._dsa.close()
                self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                self._police_officer.forget(self._response)
            else:
                self._response.close()
            self._response = None
        self._police_officer = None
    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote.

        Returns ``None`` once the remote closed the WebSocket.
        """
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            # we may have pending event to unpack!
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    self.close()
                    return None
                elif isinstance(event, Ping):
                    # Answer pings transparently; a send failure is fatal.
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        self.close()
                        raise ProtocolError from e
                    with self._write_error_catcher():
                        self._dsa.sendall(data_to_send)
            while True:
                with self._read_error_catcher():
                    data, eot, _ = self._dsa.recv_extended(None)
                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e
                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        self.close()
                        return None
                    elif isinstance(event, Ping):
                        try:
                            data_to_send = self._protocol.send(event.response())
                        except WebSocketProtocolError as e:
                            self.close()
                            raise ProtocolError from e
                        with self._write_error_catcher():
                            self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        # Nothing to surface for a pong; keep reading.
                        continue
    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote (TextMessage for str, BytesMessage for bytes)."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            try:
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                self.close()
                raise ProtocolError from e
            with self._write_error_catcher():
                self._dsa.sendall(data_to_send)
    def ping(self) -> None:
        """Send an unsolicited Ping frame (no-op once the remote shut down)."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._police_officer.borrow(self._response):
            if self._remote_shutdown is False:
                try:
                    data_to_send: bytes = self._protocol.send(Ping())
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e
                with self._write_error_catcher():
                    self._dsa.sendall(data_to_send)
    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"ws": "http", "wss": "https"}[scheme]
class WebSocketExtensionFromMultiplexedHTTP(WebSocketExtensionFromHTTP):
    """
    Plugin that supports doing WebSocket over HTTP/2 and HTTP/3.

    This implements RFC 8441 (bootstrapping WebSockets with HTTP/2). Beware
    that this isn't actually supported by many servers across the internet.
    """
    @staticmethod
    def implementation() -> str:
        return "rfc8441"  # also known as rfc9220 (http3)
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # NOTE(review): h11 is also listed — presumably so negotiation can
        # fall back to the classic RFC 6455 upgrade path; confirm.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}