fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
159
.venv/lib/python3.9/site-packages/urllib3_future/__init__.py
Normal file
159
.venv/lib/python3.9/site-packages/urllib3_future/__init__.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Python HTTP library with thread-safe connection pooling, file post support, user-friendly, and more
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Set default logging handler to avoid "No handler found" warnings.
|
||||
import logging
|
||||
import typing
|
||||
import warnings
|
||||
from logging import NullHandler
|
||||
from os import environ
|
||||
|
||||
from . import exceptions
|
||||
from ._async.connectionpool import AsyncHTTPConnectionPool, AsyncHTTPSConnectionPool
|
||||
from ._async.connectionpool import connection_from_url as async_connection_from_url
|
||||
from ._async.poolmanager import AsyncPoolManager, AsyncProxyManager
|
||||
from ._async.poolmanager import proxy_from_url as async_proxy_from_url
|
||||
from ._async.response import AsyncHTTPResponse
|
||||
from ._collections import HTTPHeaderDict
|
||||
from ._typing import _TYPE_BODY, _TYPE_FIELDS
|
||||
from ._version import __version__
|
||||
from .backend import ConnectionInfo, HttpVersion, ResponsePromise
|
||||
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
|
||||
from .contrib.resolver import ResolverDescription
|
||||
from .contrib.resolver._async import AsyncResolverDescription
|
||||
from .filepost import encode_multipart_formdata
|
||||
from .poolmanager import PoolManager, ProxyManager, proxy_from_url
|
||||
from .response import BaseHTTPResponse, HTTPResponse
|
||||
from .util.request import make_headers
|
||||
from .util.retry import Retry
|
||||
from .util.timeout import Timeout
|
||||
|
||||
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
|
||||
__license__ = "MIT"
|
||||
__version__ = __version__
|
||||
|
||||
__all__ = (
|
||||
"HTTPConnectionPool",
|
||||
"HTTPHeaderDict",
|
||||
"HTTPSConnectionPool",
|
||||
"PoolManager",
|
||||
"ProxyManager",
|
||||
"HTTPResponse",
|
||||
"Retry",
|
||||
"Timeout",
|
||||
"add_stderr_logger",
|
||||
"connection_from_url",
|
||||
"disable_warnings",
|
||||
"encode_multipart_formdata",
|
||||
"make_headers",
|
||||
"proxy_from_url",
|
||||
"request",
|
||||
"BaseHTTPResponse",
|
||||
"HttpVersion",
|
||||
"ConnectionInfo",
|
||||
"ResponsePromise",
|
||||
"ResolverDescription",
|
||||
"AsyncHTTPResponse",
|
||||
"AsyncResolverDescription",
|
||||
"AsyncHTTPConnectionPool",
|
||||
"AsyncHTTPSConnectionPool",
|
||||
"AsyncPoolManager",
|
||||
"AsyncProxyManager",
|
||||
"async_proxy_from_url",
|
||||
"async_connection_from_url",
|
||||
)
|
||||
|
||||
logging.getLogger(__name__).addHandler(NullHandler())
|
||||
|
||||
|
||||
def add_stderr_logger(
    level: int = logging.DEBUG,
) -> logging.StreamHandler[typing.TextIO]:
    """Attach a stderr :class:`logging.StreamHandler` to the package logger.

    Handy for quickly turning on debug output while troubleshooting.
    The freshly created handler is returned so the caller can detach or
    reconfigure it later.
    """
    # Building the handler first; it is independent of the logger lookup.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    )
    # This method needs to live in this __init__.py so __name__ stays correct
    # even if urllib3 is vendored within another package.
    package_logger = logging.getLogger(__name__)
    package_logger.addHandler(stream_handler)
    package_logger.setLevel(level)
    package_logger.debug("Added a stderr logging handler to logger: %s", __name__)
    return stream_handler
|
||||
|
||||
|
||||
# ... Clean up.
# NullHandler was only needed once above; drop it from the module namespace.
del NullHandler


# Warn loudly when TLS/QUIC key-logging debug variables are present: they
# cause secrets to be written to disk and must never be set in production.
if (
    environ.get("SSHKEYLOGFILE", None) is not None
    or environ.get("QUICLOGDIR", None) is not None
):
    warnings.warn(  # Defensive: security warning only. not feature.
        "urllib3.future detected that development/debug environment variable are set. "
        "If you are unaware of it please audit your environment. "
        "Variables 'SSHKEYLOGFILE' and 'QUICLOGDIR' can only be set in a non-production environment.",
        exceptions.SecurityWarning,
    )

# All warning filters *must* be appended unless you're really certain that they
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
|
||||
|
||||
|
||||
def disable_warnings(
    category: type[Warning] = exceptions.HTTPWarning,
) -> None:
    """Silence every urllib3 warning of *category* (and its subclasses).

    By default the entire ``HTTPWarning`` hierarchy is muted through the
    process-global warnings filter.
    """
    warnings.simplefilter("ignore", category)
|
||||
|
||||
|
||||
_DEFAULT_POOL = PoolManager()


def request(
    method: str,
    url: str,
    *,
    body: _TYPE_BODY | None = None,
    fields: _TYPE_FIELDS | None = None,
    headers: typing.Mapping[str, str] | None = None,
    preload_content: bool | None = True,
    decode_content: bool | None = True,
    redirect: bool | None = True,
    retries: Retry | bool | int | None = None,
    timeout: Timeout | float | int | None = 3,
    json: typing.Any | None = None,
) -> HTTPResponse:
    """Issue a request through a shared, module-global ``PoolManager``.

    Because the pool is global, its side effects (connection reuse,
    configuration) are shared across every dependency that relies on it;
    create a dedicated ``PoolManager`` instance to avoid that.
    Low-level ``**urlopen_kw`` arguments are deliberately not accepted.
    """
    forwarded: dict[str, typing.Any] = {
        "body": body,
        "fields": fields,
        "headers": headers,
        "preload_content": preload_content,
        "decode_content": decode_content,
        "redirect": redirect,
        "retries": retries,
        "timeout": timeout,
        "json": json,
    }
    return _DEFAULT_POOL.request(method, url, **forwarded)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,539 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json as _json
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
from .._collections import HTTPHeaderDict
|
||||
from .._typing import _TYPE_BODY
|
||||
from ..backend._async import AsyncLowLevelResponse
|
||||
from ..exceptions import (
|
||||
BaseSSLError,
|
||||
HTTPError,
|
||||
IncompleteRead,
|
||||
ProtocolError,
|
||||
ReadTimeoutError,
|
||||
ResponseNotReady,
|
||||
SSLError,
|
||||
MustRedialError,
|
||||
)
|
||||
from ..response import ContentDecoder, HTTPResponse
|
||||
from ..util.response import is_fp_closed, BytesQueueBuffer
|
||||
from ..util.retry import Retry
|
||||
from .connection import AsyncHTTPConnection
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from email.message import Message
|
||||
|
||||
from .._async.connectionpool import AsyncHTTPConnectionPool
|
||||
from ..contrib.webextensions._async import AsyncExtensionFromHTTP
|
||||
from ..util._async.traffic_police import AsyncTrafficPolice
|
||||
|
||||
|
||||
class AsyncHTTPResponse(HTTPResponse):
    """Asynchronous variant of :class:`HTTPResponse`.

    Mirrors the synchronous response object but exposes awaitable
    ``read``/``stream``/``close`` and async iteration. Connection lifecycle
    bookkeeping is delegated to an optional ``AsyncTrafficPolice`` instance
    (``police_officer``) shared with the owning pool.
    """

    def __init__(
        self,
        body: _TYPE_BODY = "",
        headers: typing.Mapping[str, str] | typing.Mapping[bytes, bytes] | None = None,
        status: int = 0,
        version: int = 0,
        reason: str | None = None,
        preload_content: bool = True,
        decode_content: bool = True,
        original_response: AsyncLowLevelResponse | None = None,
        pool: AsyncHTTPConnectionPool | None = None,
        connection: AsyncHTTPConnection | None = None,
        msg: Message | None = None,
        retries: Retry | None = None,
        enforce_content_length: bool = True,
        request_method: str | None = None,
        request_url: str | None = None,
        auto_close: bool = True,
        police_officer: AsyncTrafficPolice[AsyncHTTPConnection] | None = None,
    ) -> None:
        # Normalize headers into an HTTPHeaderDict without copying when possible.
        if isinstance(headers, HTTPHeaderDict):
            self.headers = headers
        else:
            self.headers = HTTPHeaderDict(headers)  # type: ignore[arg-type]
        try:
            self.status = int(status)
        except ValueError:
            self.status = 0  # merely for tests, was supported due to broken httplib.
        self.version = version
        self.reason = reason
        self.decode_content = decode_content
        self._has_decoded_content = False
        self._request_url: str | None = request_url
        self._retries: Retry | None = None

        # Optional HTTP extension (e.g. upgraded protocol); see start_extension().
        self._extension: AsyncExtensionFromHTTP | None = None  # type: ignore[assignment]

        self.retries = retries

        self.chunked = False

        # Detect chunked transfer-encoding from the response headers.
        if "transfer-encoding" in self.headers:
            tr_enc = self.headers.get("transfer-encoding", "").lower()
            # Don't incur the penalty of creating a list and then discarding it
            encodings = (enc.strip() for enc in tr_enc.split(","))

            if "chunked" in encodings:
                self.chunked = True

        # Content decoder is created lazily on first read (see parent class).
        self._decoder: ContentDecoder | None = None

        self.enforce_content_length = enforce_content_length
        self.auto_close = auto_close

        self._body = None
        self._fp: AsyncLowLevelResponse | typing.IO[typing.Any] | None = None  # type: ignore[assignment]
        self._original_response = original_response  # type: ignore[assignment]
        self._fp_bytes_read = 0

        # msg= is accepted only for backward compatibility; it is a no-op.
        if msg is not None:
            warnings.warn(
                "Passing msg=.. is deprecated and no-op in urllib3.future and is scheduled to be removed in a future major.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.msg = msg

        # Pre-materialized bodies (str/bytes) are stored directly.
        if body and isinstance(body, (str, bytes)):
            self._body = body

        self._pool: AsyncHTTPConnectionPool = pool  # type: ignore[assignment]
        self._connection: AsyncHTTPConnection = connection  # type: ignore[assignment]

        # File-like bodies are streamed through self._fp instead.
        if hasattr(body, "read"):
            self._fp = body  # type: ignore[assignment]

        # Are we using the chunked-style of transfer encoding?
        self.chunk_left: int | None = None

        # Determine length of response
        self._request_method: str | None = request_method
        self.length_remaining: int | None = self._init_length(self._request_method)

        # Used to return the correct amount of bytes for partial read()s
        self._decoded_buffer = BytesQueueBuffer()

        self._police_officer: AsyncTrafficPolice[AsyncHTTPConnection] | None = (
            police_officer  # type: ignore[assignment]
        )

        self._preloaded_content: bool = preload_content

        # Register this response (and its pool) with the traffic police so the
        # connection can be tracked until the body is fully consumed.
        if self._police_officer is not None:
            self._police_officer.memorize(self, self._connection)
            # we can utilize a ConnectionPool without level-0 PoolManager!
            if self._police_officer.parent is not None:
                self._police_officer.parent.memorize(self, self._pool)

    async def readinto(self, b: bytearray) -> int:  # type: ignore[override]
        """Read up to ``len(b)`` bytes into *b*; return the number of bytes read."""
        temp = await self.read(len(b))
        if len(temp) == 0:
            return 0
        else:
            b[: len(temp)] = temp
            return len(temp)

    @asynccontextmanager
    async def _error_catcher(self) -> typing.AsyncGenerator[None, None]:  # type: ignore[override]
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.

        On exit, release the connection back to the pool.
        """
        clean_exit = False

        try:
            try:
                yield

            except SocketTimeout as e:
                # FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
                # there is yet no clean way to get at it from this context.
                raise ReadTimeoutError(self._pool, None, "Read timed out.") from e  # type: ignore[arg-type]

            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e

                raise ReadTimeoutError(self._pool, None, "Read timed out.") from e  # type: ignore[arg-type]

            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e

            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._original_response:
                    self._original_response.close()

                # Closing the response may not actually be sufficient to close
                # everything, so if we have a hold of the connection close that
                # too.
                if self._connection:
                    await self._connection.close()

            # If we hold the original response but it's closed now, we should
            # return the connection back to the pool.
            if self._original_response and self._original_response.isclosed():
                self.release_conn()

    async def drain_conn(self) -> None:  # type: ignore[override]
        """
        Read and discard any remaining HTTP response data in the response connection.

        Unread data in the HTTPResponse connection blocks the connection from being released back to the pool.
        """
        try:
            await self.read(
                # Do not spend resources decoding the content unless
                # decoding has already been initiated.
                decode_content=self._has_decoded_content,
            )
        except (HTTPError, OSError, BaseSSLError):
            # Best-effort drain: failures here simply mean the connection
            # cannot be reused.
            pass

    @property
    def trailers(self) -> HTTPHeaderDict | None:
        """
        Retrieve post-response (trailing headers) if any.
        This WILL return None if no HTTP Trailer Headers have been received.
        """
        if self._fp is None:
            return None

        if hasattr(self._fp, "trailers"):
            return self._fp.trailers

        return None

    @property
    def extension(self) -> AsyncExtensionFromHTTP | None:  # type: ignore[override]
        # The currently attached HTTP extension, if any (see start_extension).
        return self._extension

    async def start_extension(self, item: AsyncExtensionFromHTTP) -> None:  # type: ignore[override]
        """Attach and start an HTTP extension on this response.

        Raises ``OSError`` if an extension is already attached, and
        ``ResponseNotReady`` when the low-level response cannot host one.
        """
        if self._extension is not None:
            raise OSError("extension already plugged in")

        # _dsa presence signals the low-level response supports extensions.
        if not hasattr(self._fp, "_dsa"):
            raise ResponseNotReady()

        await item.start(self)

        self._extension = item

    async def json(self) -> typing.Any:
        """
        Parses the body of the HTTP response as JSON.

        To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to the decoder.

        This method can raise either `UnicodeDecodeError` or `json.JSONDecodeError`.

        Read more :ref:`here <json>`.
        """
        data = (await self.data).decode("utf-8")
        return _json.loads(data)

    @property
    async def data(self) -> bytes:  # type: ignore[override]
        """Entire (cached) response body; reads and caches it on first access."""
        # For backwards-compat with earlier urllib3 0.4 and earlier.
        if self._body:
            return self._body  # type: ignore[return-value]

        if self._fp:
            return await self.read(cache_content=True)

        return None  # type: ignore[return-value]

    async def _fp_read(self, amt: int | None = None) -> bytes:  # type: ignore[override]
        """
        Read a response with the thought that reading the number of bytes
        larger than can fit in a 32-bit int at a time via SSL in some
        known cases leads to an overflow error that has to be prevented
        if `amt` or `self.length_remaining` indicate that a problem may
        happen.

        The known cases:
          * 3.8 <= CPython < 3.9.7 because of a bug
            https://github.com/urllib3/urllib3/issues/2513#issuecomment-1152559900.
          * urllib3 injected with pyOpenSSL-backed SSL-support.
          * CPython < 3.10 only when `amt` does not fit 32-bit int.
        """
        assert self._fp
        c_int_max = 2**31 - 1
        if (
            (amt and amt > c_int_max)
            or (self.length_remaining and self.length_remaining > c_int_max)
        ) and sys.version_info < (3, 10):
            # Chunked workaround path: accumulate into a BytesIO buffer.
            buffer = io.BytesIO()
            # Besides `max_chunk_amt` being a maximum chunk size, it
            # affects memory overhead of reading a response by this
            # method in CPython.
            # `c_int_max` equal to 2 GiB - 1 byte is the actual maximum
            # chunk size that does not lead to an overflow error, but
            # 256 MiB is a compromise.
            max_chunk_amt = 2**28
            while amt is None or amt != 0:
                if amt is not None:
                    chunk_amt = min(amt, max_chunk_amt)
                    amt -= chunk_amt
                else:
                    chunk_amt = max_chunk_amt
                try:
                    if isinstance(self._fp, AsyncLowLevelResponse):
                        data = await self._fp.read(chunk_amt)
                    else:
                        data = self._fp.read(chunk_amt)  # type: ignore[attr-defined]
                except ValueError:  # Defensive: overly protective
                    break  # Defensive: can also be an indicator that read ended, should not happen.
                if not data:
                    break
                buffer.write(data)
                del data  # to reduce peak memory usage by `max_chunk_amt`.
            return buffer.getvalue()
        else:
            # StringIO doesn't like amt=None
            if isinstance(self._fp, AsyncLowLevelResponse):
                return await self._fp.read(amt)
            return self._fp.read(amt) if amt is not None else self._fp.read()  # type: ignore[no-any-return]

    async def _raw_read(  # type: ignore[override]
        self,
        amt: int | None = None,
    ) -> bytes:
        """
        Reads `amt` of bytes from the socket.
        """
        if self._fp is None:
            return None  # type: ignore[return-value]

        fp_closed = getattr(self._fp, "closed", False)

        async with self._error_catcher():
            data = (await self._fp_read(amt)) if not fp_closed else b""

            # Mocking library often use io.BytesIO
            # which does not auto-close when reading data
            # with amt=None.
            is_foreign_fp_unclosed = (
                amt is None and getattr(self._fp, "closed", False) is False
            )

            if (amt is not None and amt != 0 and not data) or is_foreign_fp_unclosed:
                if is_foreign_fp_unclosed:
                    self._fp_bytes_read += len(data)
                    if self.length_remaining is not None:
                        self.length_remaining -= len(data)

                # Platform-specific: Buggy versions of Python.
                # Close the connection when no data is returned
                #
                # This is redundant to what httplib/http.client _should_
                # already do. However, versions of python released before
                # December 15, 2012 (http://bugs.python.org/issue16298) do
                # not properly close the connection in all cases. There is
                # no harm in redundantly calling close.
                self._fp.close()
                if (
                    self.enforce_content_length
                    and self.length_remaining is not None
                    and self.length_remaining != 0
                ):
                    # This is an edge case that httplib failed to cover due
                    # to concerns of backward compatibility. We're
                    # addressing it here to make sure IncompleteRead is
                    # raised during streaming, so all calls with incorrect
                    # Content-Length are caught.
                    raise IncompleteRead(self._fp_bytes_read, self.length_remaining)

            if data and not is_foreign_fp_unclosed:
                self._fp_bytes_read += len(data)
                if self.length_remaining is not None:
                    self.length_remaining -= len(data)

        return data

    async def read1(  # type: ignore[override]
        self,
        amt: int | None = None,
        decode_content: bool | None = None,
    ) -> bytes:
        """
        Similar to ``http.client.HTTPResponse.read1`` and documented
        in :meth:`io.BufferedReader.read1`, but with an additional parameter:
        ``decode_content``.

        :param amt:
            How much of the content to read.

        :param decode_content:
            If True, will attempt to decode the body based on the
            'content-encoding' header.
        """

        # amt=-1 asks read() for "whatever is available in one pass".
        data = await self.read(
            amt=amt or -1,
            decode_content=decode_content,
        )

        # If more came back than requested, stash the surplus for later reads.
        if amt is not None and len(data) > amt:
            self._decoded_buffer.put(data)
            return self._decoded_buffer.get(amt)

        return data

    async def read(  # type: ignore[override]
        self,
        amt: int | None = None,
        decode_content: bool | None = None,
        cache_content: bool = False,
    ) -> bytes:
        """Read up to *amt* bytes of (optionally decoded) body.

        ``amt=None`` reads everything; ``amt=-1`` drains what is buffered or
        performs a single raw read. ``cache_content`` stores the full body on
        the instance (only meaningful with ``amt=None``).
        """
        try:
            self._init_decoder()
            if decode_content is None:
                decode_content = self.decode_content

            if amt is not None:
                cache_content = False

                # Serve entirely from the decoded buffer when possible.
                if amt < 0 and len(self._decoded_buffer):
                    return self._decoded_buffer.get(len(self._decoded_buffer))

                if 0 < amt <= len(self._decoded_buffer):
                    return self._decoded_buffer.get(amt)

            # Raw reads go through the traffic police when present so the
            # underlying connection is properly leased.
            if self._police_officer is not None:
                async with self._police_officer.borrow(self):
                    data = await self._raw_read(amt)
            else:
                data = await self._raw_read(amt)

            if amt and amt < 0:
                amt = len(data)

            flush_decoder = False
            if amt is None:
                flush_decoder = True
            elif amt != 0 and not data:
                flush_decoder = True

            if not data and len(self._decoded_buffer) == 0:
                return data

            if amt is None:
                data = self._decode(data, decode_content, flush_decoder)
                if cache_content:
                    self._body = data
            else:
                # do not waste memory on buffer when not decoding
                if not decode_content:
                    if self._has_decoded_content:
                        raise RuntimeError(
                            "Calling read(decode_content=False) is not supported after "
                            "read(decode_content=True) was called."
                        )
                    return data

                decoded_data = self._decode(data, decode_content, flush_decoder)
                self._decoded_buffer.put(decoded_data)

                while len(self._decoded_buffer) < amt and data:
                    # TODO make sure to initially read enough data to get past the headers
                    # For example, the GZ file header takes 10 bytes, we don't want to read
                    # it one byte at a time
                    if self._police_officer is not None:
                        async with self._police_officer.borrow(self):
                            data = await self._raw_read(amt)
                    else:
                        data = await self._raw_read(amt)

                    decoded_data = self._decode(data, decode_content, flush_decoder)
                    self._decoded_buffer.put(decoded_data)
                data = self._decoded_buffer.get(amt)

            return data
        finally:
            # Once the low-level response reached end-of-transmission, release
            # our traffic-police registration so the connection can be reused.
            if (
                self._fp
                and hasattr(self._fp, "_eot")
                and self._fp._eot
                and self._police_officer is not None
            ):
                # an HTTP extension could be live, we don't want to accidentally kill it!
                if (
                    not hasattr(self._fp, "_dsa")
                    or self._fp._dsa is None
                    or self._fp._dsa.closed is True
                ):
                    self._police_officer.forget(self)
                    self._police_officer = None

    async def stream(  # type: ignore[override]
        self, amt: int | None = 2**16, decode_content: bool | None = None
    ) -> typing.AsyncGenerator[bytes, None]:
        """Yield the body in chunks of at most *amt* bytes until exhausted."""
        if self._fp is None:
            return
        while not is_fp_closed(self._fp) or len(self._decoded_buffer) > 0:
            data = await self.read(amt=amt, decode_content=decode_content)

            if data:
                yield data

    async def close(self) -> None:  # type: ignore[override]
        """Close the response, any live extension, and the underlying connection."""
        if self.extension is not None and self.extension.closed is False:
            await self.extension.close()

        if not self.closed and self._fp:
            self._fp.close()

        if self._connection:
            await self._connection.close()

        if not self.auto_close:
            io.IOBase.close(self)

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Iterate over the decoded body line by line (split on ``\\n``)."""
        buffer: list[bytes] = []
        async for chunk in self.stream(-1, decode_content=True):
            if b"\n" in chunk:
                chunks = chunk.split(b"\n")
                # First piece completes whatever was buffered from earlier chunks.
                yield b"".join(buffer) + chunks[0] + b"\n"
                for x in chunks[1:-1]:
                    yield x + b"\n"
                if chunks[-1]:
                    buffer = [chunks[-1]]
                else:
                    buffer = []
            else:
                buffer.append(chunk)
        if buffer:
            yield b"".join(buffer)

    def __del__(self) -> None:
        # Best-effort cleanup; must stay synchronous (no awaits in __del__).
        if not self.closed:
            if not self.closed and self._fp:
                self._fp.close()

            if not self.auto_close:
                io.IOBase.close(self)
|
||||
440
.venv/lib/python3.9/site-packages/urllib3_future/_collections.py
Normal file
440
.venv/lib/python3.9/site-packages/urllib3_future/_collections.py
Normal file
@@ -0,0 +1,440 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections import OrderedDict
|
||||
from enum import Enum, auto
|
||||
from functools import lru_cache
|
||||
from threading import RLock
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# We can only import Protocol if TYPE_CHECKING because it's a development
|
||||
# dependency, and is not available at runtime.
|
||||
from typing_extensions import Protocol
|
||||
|
||||
class HasGettableStringKeys(Protocol):
|
||||
def keys(self) -> typing.Iterator[str]: ...
|
||||
|
||||
def __getitem__(self, key: str) -> str: ...
|
||||
|
||||
|
||||
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
|
||||
|
||||
|
||||
# Key type
|
||||
_KT = typing.TypeVar("_KT")
|
||||
# Value type
|
||||
_VT = typing.TypeVar("_VT")
|
||||
# Default type
|
||||
_DT = typing.TypeVar("_DT")
|
||||
|
||||
ValidHTTPHeaderSource = typing.Union[
|
||||
"HTTPHeaderDict",
|
||||
typing.Mapping[str, str],
|
||||
typing.Iterable[typing.Tuple[str, str]],
|
||||
"HasGettableStringKeys",
|
||||
]
|
||||
|
||||
|
||||
class _Sentinel(Enum):
    # Unique marker used to tell "argument not passed" apart from an
    # explicit None.
    not_passed = auto()
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _lower_wrapper(string: str) -> str:
|
||||
"""Reasoning: We are often calling lower on repetitive identical header key. This was unnecessary exhausting!"""
|
||||
return string.lower()
|
||||
|
||||
|
||||
def ensure_can_construct_http_header_dict(
    potential: object,
) -> ValidHTTPHeaderSource | None:
    """Classify *potential* as a usable header source, or return ``None``.

    Accepted shapes, checked in order of specificity: an existing
    ``HTTPHeaderDict``, any ``Mapping``, any ``Iterable`` of pairs, or a
    duck-typed object exposing ``keys()`` and ``__getitem__``. The contents
    are NOT validated at runtime — only the container shape is checked.
    """
    if isinstance(potential, HTTPHeaderDict):
        return potential

    if isinstance(potential, typing.Mapping):
        # Full runtime checking of the contents of a Mapping is expensive, so
        # for the purposes of typechecking, we assume any Mapping is the
        # right shape.
        return typing.cast(typing.Mapping[str, str], potential)

    if isinstance(potential, typing.Iterable):
        # Same reasoning as for Mapping: trust the shape, skip deep checks.
        return typing.cast(typing.Iterable[typing.Tuple[str, str]], potential)

    if hasattr(potential, "keys") and hasattr(potential, "__getitem__"):
        return typing.cast("HasGettableStringKeys", potential)

    return None
|
||||
|
||||
|
||||
class RecentlyUsedContainer(typing.Generic[_KT, _VT], typing.MutableMapping[_KT, _VT]):
    """
    Provides a thread-safe dict-like container which maintains up to
    ``maxsize`` keys while throwing away the least-recently-used keys beyond
    ``maxsize``. Caution: RecentlyUsedContainer is deprecated and scheduled for
    removal in a next major of urllib3.future. It has been replaced by a more
    suitable implementation in ``urllib3.util.traffic_police``.

    :param maxsize:
        Maximum number of recent elements to retain.

    :param dispose_func:
        Every time an item is evicted from the container,
        ``dispose_func(value)`` is called. Callback which will get called
    """

    # Insertion order doubles as the LRU eviction order.
    _container: typing.OrderedDict[_KT, _VT]
    _maxsize: int
    dispose_func: typing.Callable[[_VT], None] | None
    lock: RLock

    def __init__(
        self,
        maxsize: int = 10,
        dispose_func: typing.Callable[[_VT], None] | None = None,
    ) -> None:
        super().__init__()
        self._maxsize = maxsize
        self.dispose_func = dispose_func
        self._container = OrderedDict()
        self.lock = RLock()

    def __getitem__(self, key: _KT) -> _VT:
        # Re-insert the item, moving it to the end of the eviction line.
        with self.lock:
            item = self._container.pop(key)
            self._container[key] = item
            return item

    def __setitem__(self, key: _KT, value: _VT) -> None:
        evicted_item = None
        with self.lock:
            # Possibly evict the existing value of 'key'
            try:
                # If the key exists, we'll overwrite it, which won't change the
                # size of the pool. Because accessing a key should move it to
                # the end of the eviction line, we pop it out first.
                evicted_item = key, self._container.pop(key)
                self._container[key] = value
            except KeyError:
                # When the key does not exist, we insert the value first so that
                # evicting works in all cases, including when self._maxsize is 0
                self._container[key] = value
                if len(self._container) > self._maxsize:
                    # If we didn't evict an existing value, and we've hit our maximum
                    # size, then we have to evict the least recently used item from
                    # the beginning of the container.
                    evicted_item = self._container.popitem(last=False)

        # After releasing the lock on the pool, dispose of any evicted value.
        # (Deliberately outside the lock: dispose_func may be slow or re-enter.)
        if evicted_item is not None and self.dispose_func:
            _, evicted_value = evicted_item
            self.dispose_func(evicted_value)

    def __delitem__(self, key: _KT) -> None:
        with self.lock:
            value = self._container.pop(key)

        # Dispose outside the lock, mirroring __setitem__.
        if self.dispose_func:
            self.dispose_func(value)

    def __len__(self) -> int:
        with self.lock:
            return len(self._container)

    def __iter__(self) -> typing.NoReturn:
        # Iteration is intentionally unsupported: holding the lock for the
        # whole traversal is not practical, and lock-free iteration races.
        raise NotImplementedError(
            "Iteration over this class is unlikely to be threadsafe."
        )

    def clear(self) -> None:
        with self.lock:
            # Copy pointers to all values, then wipe the mapping
            values = list(self._container.values())
            self._container.clear()

        # Dispose outside the lock, mirroring __setitem__.
        if self.dispose_func:
            for value in values:
                self.dispose_func(value)

    def keys(self) -> set[_KT]:  # type: ignore[override]
        # Returns a snapshot set rather than a live view, for thread safety.
        with self.lock:
            return set(self._container.keys())
|
||||
|
||||
|
||||
class HTTPHeaderDictItemView(typing.Set[typing.Tuple[str, str]]):
    """Set-like view over the ``(name, value)`` pairs of an ``HTTPHeaderDict``.

    An ``HTTPHeaderDict`` exposes two addressing modes. Direct subscript
    access merges every value stored under a name into one string:

    >>> d['X-Header-Name']
    'Value1, Value2, Value3'

    Iterating this items view, however, yields one pair per stored value,
    unless values were folded together via ``add(..., combine=True)``:

    >>> d = HTTPHeaderDict({"A": "1", "B": "foo"})
    >>> d.add("A", "2", combine=True)
    >>> d.add("B", "bar")
    >>> list(d.items())
    [
        ('A', '1, 2'),
        ('B', 'foo'),
        ('B', 'bar'),
    ]

    This keeps the MutableMapping ABC contract satisfied while providing the
    nonstandard iteration order we want: duplicate-keyed items, ordered by
    the time their key was first inserted.
    """

    _headers: HTTPHeaderDict

    def __init__(self, headers: HTTPHeaderDict) -> None:
        self._headers = headers

    def __len__(self) -> int:
        # Duplicate header names contribute one pair per value, so count the
        # expanded pairs rather than the number of distinct names.
        return sum(1 for _ in self._headers.iteritems())

    def __iter__(self) -> typing.Iterator[tuple[str, str]]:
        return self._headers.iteritems()

    def __contains__(self, item: object) -> bool:
        # Membership is only defined for (str, str) pairs.
        if not (isinstance(item, tuple) and len(item) == 2):
            return False
        name, value = item
        if isinstance(name, str) and isinstance(value, str):
            return self._headers._has_value_for_header(name, value)
        return False
|
||||
|
||||
|
||||
class HTTPHeaderDict(typing.MutableMapping[str, str]):
    """
    :param headers:
        An iterable of field-value pairs. Must not contain multiple field names
        when compared case-insensitively.

    :param kwargs:
        Additional field-value pairs to pass in to ``dict.update``.

    A ``dict`` like container for storing HTTP Headers.

    Field names are stored and compared case-insensitively in compliance with
    RFC 7230. Iteration provides the first case-sensitive key seen for each
    case-insensitive pair.

    Using ``__setitem__`` syntax overwrites fields that compare equal
    case-insensitively in order to maintain ``dict``'s api. For fields that
    compare equal, instead create a new ``HTTPHeaderDict`` and use ``.add``
    in a loop.

    If multiple fields that are equal case-insensitively are passed to the
    constructor or ``.update``, the behavior is undefined and some will be
    lost.

    >>> headers = HTTPHeaderDict()
    >>> headers.add('Set-Cookie', 'foo=bar')
    >>> headers.add('set-cookie', 'baz=quxx')
    >>> headers['content-length'] = '7'
    >>> headers['SET-cookie']
    'foo=bar, baz=quxx'
    >>> headers['Content-Length']
    '7'
    """

    # Internal invariant: maps the lowercased header name to a list whose
    # element 0 is the first-seen original casing of the name and whose
    # elements 1..n are the stored values, in insertion order.
    _container: typing.MutableMapping[str, list[str]]

    def __init__(self, headers: ValidHTTPHeaderSource | None = None, **kwargs: str):
        super().__init__()
        self._container = {}  # 'dict' is insert-ordered in Python 3.7+
        if headers is not None:
            if isinstance(headers, HTTPHeaderDict):
                # Fast path: copy the internal lists directly.
                self._copy_from(headers)
            else:
                self.extend(headers)
        if kwargs:
            self.extend(kwargs)

    def __setitem__(self, key: str, val: str) -> None:
        # avoid a bytes/str comparison by decoding before httplib
        # Overwrites every value stored under a case-insensitively equal name.
        self._container[_lower_wrapper(key)] = [key, val]

    def __getitem__(self, key: str) -> str:
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        val = self._container[_lower_wrapper(key)]
        # Multiple stored values are folded into one comma-separated string.
        return ", ".join(val[1:])

    def __delitem__(self, key: str) -> None:
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        del self._container[_lower_wrapper(key)]

    def __contains__(self, key: object) -> bool:
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        if isinstance(key, str):
            return _lower_wrapper(key) in self._container
        # Non-string keys can never be present.
        return False

    def setdefault(self, key: str, default: str = "") -> str:
        # Narrows MutableMapping's ``default=None`` to a ``str`` default.
        return super().setdefault(key, default)

    def __eq__(self, other: object) -> bool:
        maybe_constructable = ensure_can_construct_http_header_dict(other)
        if maybe_constructable is None:
            return False
        else:
            other_as_http_header_dict = type(self)(maybe_constructable)

        # Equality compares case-insensitive names with merged values.
        return {_lower_wrapper(k): v for k, v in self.itermerged()} == {
            _lower_wrapper(k): v for k, v in other_as_http_header_dict.itermerged()
        }

    def __ne__(self, other: object) -> bool:
        return not self.__eq__(other)

    def __len__(self) -> int:
        # Counts distinct (case-insensitive) header names, not values.
        return len(self._container)

    def __iter__(self) -> typing.Iterator[str]:
        # Only provide the originally cased names
        for vals in self._container.values():
            yield vals[0]

    def discard(self, key: str) -> None:
        # Like ``del self[key]`` but silent when the key is absent.
        try:
            del self[key]
        except KeyError:
            pass

    def add(self, key: str, val: str, *, combine: bool = False) -> None:
        """Adds a (name, value) pair, doesn't overwrite the value if it already
        exists.

        If this is called with combine=True, instead of adding a new header value
        as a distinct item during iteration, this will instead append the value to
        any existing header value with a comma. If no existing header value exists
        for the key, then the value will simply be added, ignoring the combine parameter.

        >>> headers = HTTPHeaderDict(foo='bar')
        >>> headers.add('Foo', 'baz')
        >>> headers['foo']
        'bar, baz'
        >>> list(headers.items())
        [('foo', 'bar'), ('foo', 'baz')]
        >>> headers.add('foo', 'quz', combine=True)
        >>> list(headers.items())
        [('foo', 'bar, baz, quz')]
        """
        key_lower = _lower_wrapper(key)
        new_vals = [key, val]
        # Keep the common case aka no item present as fast as possible
        vals = self._container.setdefault(key_lower, new_vals)
        if new_vals is not vals:
            # if there are values here, then there is at least the initial
            # key/value pair
            if combine:
                vals[-1] = vals[-1] + ", " + val
            else:
                vals.append(val)

    def extend(self, *args: ValidHTTPHeaderSource, **kwargs: str) -> None:
        """Generic import function for any type of header-like object.
        Adapted version of MutableMapping.update in order to insert items
        with self.add instead of self.__setitem__
        """
        if len(args) > 1:
            raise TypeError(
                f"extend() takes at most 1 positional arguments ({len(args)} given)"
            )
        other = args[0] if len(args) >= 1 else ()

        if isinstance(other, HTTPHeaderDict):
            for key, val in other.iteritems():
                self.add(key, val)
        elif isinstance(other, typing.Mapping):
            for key, val in other.items():
                self.add(key, val)
        elif isinstance(other, typing.Iterable):
            for key, value in other:
                self.add(key, value)
        elif hasattr(other, "keys") and hasattr(other, "__getitem__"):
            # THIS IS NOT A TYPESAFE BRANCH
            # In this branch, the object has a `keys` attr but is not a Mapping or any of
            # the other types indicated in the method signature. We do some stuff with
            # it as though it partially implements the Mapping interface, but we're not
            # doing that stuff safely AT ALL.
            for key in other.keys():
                self.add(key, other[key])

        for key, value in kwargs.items():
            self.add(key, value)

    @typing.overload
    def getlist(self, key: str) -> list[str]: ...

    @typing.overload
    def getlist(self, key: str, default: _DT) -> list[str] | _DT: ...

    def getlist(
        self, key: str, default: _Sentinel | _DT = _Sentinel.not_passed
    ) -> list[str] | _DT:
        """Returns a list of all the values for the named field. Returns an
        empty list if the key doesn't exist."""
        if isinstance(key, bytes):
            key = key.decode("latin-1")
        try:
            vals = self._container[_lower_wrapper(key)]
        except KeyError:
            if default is _Sentinel.not_passed:
                # _DT is unbound; empty list is instance of List[str]
                return []
            # _DT is bound; default is instance of _DT
            return default
        else:
            # _DT may or may not be bound; vals[1:] is instance of List[str], which
            # meets our external interface requirement of `Union[List[str], _DT]`.
            return vals[1:]

    # Backwards compatibility for httplib
    getheaders = getlist
    getallmatchingheaders = getlist
    iget = getlist

    # Backwards compatibility for http.cookiejar
    get_all = getlist

    def __repr__(self) -> str:
        return f"{type(self).__name__}({dict(self.itermerged())})"

    def _copy_from(self, other: HTTPHeaderDict) -> None:
        # Rebuilds the internal [cased_key, *values] lists from another dict.
        for key in other:
            val = other.getlist(key)
            self._container[_lower_wrapper(key)] = [key, *val]

    def copy(self) -> HTTPHeaderDict:
        """Return a shallow copy preserving casing and duplicate values."""
        clone = type(self)()
        clone._copy_from(self)
        return clone

    def iteritems(self) -> typing.Iterator[tuple[str, str]]:
        """Iterate over all header lines, including duplicate ones."""
        for key in self:
            vals = self._container[_lower_wrapper(key)]
            for val in vals[1:]:
                yield vals[0], val

    def itermerged(self) -> typing.Iterator[tuple[str, str]]:
        """Iterate over all headers, merging duplicate ones together."""
        for key in self:
            val = self._container[_lower_wrapper(key)]
            yield val[0], ", ".join(val[1:])

    def items(self) -> HTTPHeaderDictItemView:  # type: ignore[override]
        # Returns the duplicate-preserving items view rather than dict_items.
        return HTTPHeaderDictItemView(self)

    def _has_value_for_header(self, header_name: str, potential_value: str) -> bool:
        # Exact-match test against the individual stored values (not the
        # merged comma-joined form).
        if header_name in self:
            return potential_value in self._container[_lower_wrapper(header_name)][1:]
        return False
|
||||
255
.venv/lib/python3.9/site-packages/urllib3_future/_constant.py
Normal file
255
.venv/lib/python3.9/site-packages/urllib3_future/_constant.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class HTTPStatus(IntEnum):
    """HTTP status codes and reason phrases

    Status codes from the following RFCs are all observed:

    * RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
    * RFC 6585: Additional HTTP Status Codes
    * RFC 3229: Delta encoding in HTTP
    * RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
    * RFC 5842: Binding Extensions to WebDAV
    * RFC 7238: Permanent Redirect
    * RFC 2295: Transparent Content Negotiation in HTTP
    * RFC 2774: An HTTP Extension Framework
    * RFC 7725: An HTTP Status Code to Report Legal Obstacles
    * RFC 7540: Hypertext Transfer Protocol Version 2 (HTTP/2)
    * RFC 2324: Hyper Text Coffee Pot Control Protocol (HTCPCP/1.0)
    * RFC 8297: An HTTP Status Code for Indicating Hints
    * RFC 8470: Using Early Data in HTTP
    """

    # Extra per-member attributes attached by __new__ below.
    phrase: str
    description: str
    standard: bool

    def __new__(
        cls, value: int, phrase: str, description: str = "", is_standard: bool = True
    ) -> HTTPStatus:
        # Each member is declared as a (value, phrase[, description]) tuple;
        # the numeric code becomes the enum value, the rest become attributes.
        obj = int.__new__(cls, value)
        obj._value_ = value

        obj.phrase = phrase
        obj.description = description
        obj.standard = is_standard
        return obj

    # informational
    CONTINUE = 100, "Continue", "Request received, please continue"
    SWITCHING_PROTOCOLS = (
        101,
        "Switching Protocols",
        "Switching to new protocol; obey Upgrade header",
    )
    PROCESSING = 102, "Processing"
    EARLY_HINTS = 103, "Early Hints"

    # success
    OK = 200, "OK", "Request fulfilled, document follows"
    CREATED = 201, "Created", "Document created, URL follows"
    ACCEPTED = (202, "Accepted", "Request accepted, processing continues off-line")
    NON_AUTHORITATIVE_INFORMATION = (
        203,
        "Non-Authoritative Information",
        "Request fulfilled from cache",
    )
    NO_CONTENT = 204, "No Content", "Request fulfilled, nothing follows"
    RESET_CONTENT = 205, "Reset Content", "Clear input form for further input"
    PARTIAL_CONTENT = 206, "Partial Content", "Partial content follows"
    MULTI_STATUS = 207, "Multi-Status"
    ALREADY_REPORTED = 208, "Already Reported"
    IM_USED = 226, "IM Used"

    # redirection
    MULTIPLE_CHOICES = (
        300,
        "Multiple Choices",
        "Object has several resources -- see URI list",
    )
    MOVED_PERMANENTLY = (
        301,
        "Moved Permanently",
        "Object moved permanently -- see URI list",
    )
    FOUND = 302, "Found", "Object moved temporarily -- see URI list"
    SEE_OTHER = 303, "See Other", "Object moved -- see Method and URL list"
    NOT_MODIFIED = (304, "Not Modified", "Document has not changed since given time")
    USE_PROXY = (
        305,
        "Use Proxy",
        "You must use proxy specified in Location to access this resource",
    )
    TEMPORARY_REDIRECT = (
        307,
        "Temporary Redirect",
        "Object moved temporarily -- see URI list",
    )
    PERMANENT_REDIRECT = (
        308,
        "Permanent Redirect",
        "Object moved permanently -- see URI list",
    )

    # client error
    BAD_REQUEST = (400, "Bad Request", "Bad request syntax or unsupported method")
    UNAUTHORIZED = (401, "Unauthorized", "No permission -- see authorization schemes")
    PAYMENT_REQUIRED = (402, "Payment Required", "No payment -- see charging schemes")
    FORBIDDEN = (403, "Forbidden", "Request forbidden -- authorization will not help")
    NOT_FOUND = (404, "Not Found", "Nothing matches the given URI")
    METHOD_NOT_ALLOWED = (
        405,
        "Method Not Allowed",
        "Specified method is invalid for this resource",
    )
    NOT_ACCEPTABLE = (406, "Not Acceptable", "URI not available in preferred format")
    PROXY_AUTHENTICATION_REQUIRED = (
        407,
        "Proxy Authentication Required",
        "You must authenticate with this proxy before proceeding",
    )
    REQUEST_TIMEOUT = (408, "Request Timeout", "Request timed out; try again later")
    CONFLICT = 409, "Conflict", "Request conflict"
    GONE = (410, "Gone", "URI no longer exists and has been permanently removed")
    LENGTH_REQUIRED = (411, "Length Required", "Client must specify Content-Length")
    PRECONDITION_FAILED = (
        412,
        "Precondition Failed",
        "Precondition in headers is false",
    )
    REQUEST_ENTITY_TOO_LARGE = (413, "Request Entity Too Large", "Entity is too large")
    REQUEST_URI_TOO_LONG = (414, "Request-URI Too Long", "URI is too long")
    UNSUPPORTED_MEDIA_TYPE = (
        415,
        "Unsupported Media Type",
        "Entity body in unsupported format",
    )
    REQUESTED_RANGE_NOT_SATISFIABLE = (
        416,
        "Requested Range Not Satisfiable",
        "Cannot satisfy request range",
    )
    EXPECTATION_FAILED = (
        417,
        "Expectation Failed",
        "Expect condition could not be satisfied",
    )
    IM_A_TEAPOT = (
        418,
        "I'm a Teapot",
        "Server refuses to brew coffee because it is a teapot.",
    )
    MISDIRECTED_REQUEST = (
        421,
        "Misdirected Request",
        "Server is not able to produce a response",
    )
    UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity"
    LOCKED = 423, "Locked"
    FAILED_DEPENDENCY = 424, "Failed Dependency"
    TOO_EARLY = 425, "Too Early"
    UPGRADE_REQUIRED = 426, "Upgrade Required"
    PRECONDITION_REQUIRED = (
        428,
        "Precondition Required",
        "The origin server requires the request to be conditional",
    )
    TOO_MANY_REQUESTS = (
        429,
        "Too Many Requests",
        "The user has sent too many requests in "
        'a given amount of time ("rate limiting")',
    )
    REQUEST_HEADER_FIELDS_TOO_LARGE = (
        431,
        "Request Header Fields Too Large",
        "The server is unwilling to process the request because its header "
        "fields are too large",
    )
    UNAVAILABLE_FOR_LEGAL_REASONS = (
        451,
        "Unavailable For Legal Reasons",
        "The server is denying access to the "
        "resource as a consequence of a legal demand",
    )

    # server errors
    INTERNAL_SERVER_ERROR = (
        500,
        "Internal Server Error",
        "Server got itself in trouble",
    )
    NOT_IMPLEMENTED = (501, "Not Implemented", "Server does not support this operation")
    BAD_GATEWAY = (502, "Bad Gateway", "Invalid responses from another server/proxy")
    SERVICE_UNAVAILABLE = (
        503,
        "Service Unavailable",
        "The server cannot process the request due to a high load",
    )
    GATEWAY_TIMEOUT = (
        504,
        "Gateway Timeout",
        "The gateway server did not receive a timely response",
    )
    HTTP_VERSION_NOT_SUPPORTED = (
        505,
        "HTTP Version Not Supported",
        "Cannot fulfill request",
    )
    VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates"
    INSUFFICIENT_STORAGE = 507, "Insufficient Storage"
    LOOP_DETECTED = 508, "Loop Detected"
    NOT_EXTENDED = 510, "Not Extended"
    NETWORK_AUTHENTICATION_REQUIRED = (
        511,
        "Network Authentication Required",
        "The client needs to authenticate to gain network access",
    )
|
||||
|
||||
|
||||
# another hack to maintain backwards compatibility
# Mapping status codes to official W3C names
# Keys are HTTPStatus members, which compare equal to plain ints.
responses: typing.Mapping[int, str] = {
    v: v.phrase for v in HTTPStatus.__members__.values()
}
|
||||
|
||||
# Default value for `blocksize` - a new parameter introduced to
# http.client.HTTPConnection & http.client.HTTPSConnection in Python 3.7
# The maximum TCP packet size is 65535 octets. But OSes can have a recv buffer greater than 65k, so
# passing the highest value does improve responsiveness.
# udp usually is set to 212992
# and tcp usually is set to 131072 (16384 * 8) or (65535 * 2)
try:
    # dynamically retrieve the kernel rcvbuf size.
    # The probe sockets are closed via try/finally so we never leak file
    # descriptors at import time, even when getsockopt (or the second
    # socket()) raises OSError. The previous version only closed them on
    # the success path.
    stream = socket.socket(type=socket.SOCK_STREAM)
    try:
        dgram = socket.socket(type=socket.SOCK_DGRAM)
        try:
            DEFAULT_BLOCKSIZE: int = stream.getsockopt(
                socket.SOL_SOCKET, socket.SO_RCVBUF
            )
            UDP_DEFAULT_BLOCKSIZE: int = dgram.getsockopt(
                socket.SOL_SOCKET, socket.SO_RCVBUF
            )
        finally:
            dgram.close()
    finally:
        stream.close()
except OSError:
    # Conservative fallbacks matching common kernel defaults (see above).
    DEFAULT_BLOCKSIZE = 131072
    UDP_DEFAULT_BLOCKSIZE = 212992

TCP_DEFAULT_BLOCKSIZE: int = DEFAULT_BLOCKSIZE

UDP_LINUX_GRO: int = 104
UDP_LINUX_SEGMENT: int = 103

# Mozilla TLS recommendations for ciphers
# General-purpose servers with a variety of clients, recommended for almost all systems.
MOZ_INTERMEDIATE_CIPHERS: str = "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305"

DEFAULT_BACKGROUND_WATCH_WINDOW: float = 5.0
MINIMAL_BACKGROUND_WATCH_WINDOW: float = 0.05

DEFAULT_KEEPALIVE_DELAY: float = 3600.0
DEFAULT_KEEPALIVE_IDLE_WINDOW: float = 60.0
MINIMAL_KEEPALIVE_IDLE_WINDOW: float = 1.0
|
||||
@@ -0,0 +1,642 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json as _json
|
||||
import typing
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from ._async.response import AsyncHTTPResponse
|
||||
from ._collections import HTTPHeaderDict
|
||||
from ._typing import _TYPE_ASYNC_BODY, _TYPE_BODY, _TYPE_ENCODE_URL_FIELDS, _TYPE_FIELDS
|
||||
from .filepost import encode_multipart_formdata
|
||||
from .response import HTTPResponse
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .backend import ResponsePromise
|
||||
|
||||
__all__ = ["RequestMethods", "AsyncRequestMethods"]
|
||||
|
||||
|
||||
class RequestMethods:
    """
    Convenience mixin for classes who implement a :meth:`urlopen` method, such
    as :class:`urllib3.HTTPConnectionPool` and
    :class:`urllib3.PoolManager`.

    Provides behavior for making common types of HTTP request methods and
    decides which type of request field encoding to use.

    Specifically,

    :meth:`.request_encode_url` is for sending requests whose fields are
    encoded in the URL (such as GET, HEAD, DELETE).

    :meth:`.request_encode_body` is for sending requests whose fields are
    encoded in the *body* of the request using multipart or www-form-urlencoded
    (such as for POST, PUT, PATCH).

    :meth:`.request` is for making any kind of request, it will look up the
    appropriate encoding format and use one of the above two methods to make
    the request.

    Initializer parameters:

    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.
    """

    # Methods whose ``fields`` belong in the URL query string.
    _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}

    def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
        self.headers = headers or {}

    @typing.overload
    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> HTTPResponse: ...

    @typing.overload
    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...

    def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        # Subclasses (connection pools / pool managers) provide the transport.
        raise NotImplementedError(
            "Classes extending RequestMethods must implement "
            "their own ``urlopen`` method."
        )

    @typing.overload
    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...

    @typing.overload
    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        json: typing.Any | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the appropriate encoding of
        ``fields`` based on the ``method`` used.

        This is a convenience method that requires the least amount of manual
        effort. It can be used in most situations, while still having the
        option to drop down to more specific methods when necessary, such as
        :meth:`request_encode_url`, :meth:`request_encode_body`,
        or even the lowest level :meth:`urlopen`.

        :raises TypeError: if both ``body`` and ``json`` are supplied.
        """
        method = method.upper()

        if json is not None and body is not None:
            raise TypeError(
                "request got values for both 'body' and 'json' parameters which are mutually exclusive"
            )

        if json is not None:
            if headers is None:
                headers = self.headers.copy()  # type: ignore
            # NOTE(review): when a caller-supplied ``headers`` mapping lacks a
            # Content-Type, it is mutated in place here — confirm callers
            # expect this side effect.
            if "content-type" not in map(str.lower, headers.keys()):
                headers["Content-Type"] = "application/json"  # type: ignore

            # Compact separators and ensure_ascii=False produce the smallest
            # UTF-8 payload.
            body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
                "utf-8"
            )

        if body is not None:
            urlopen_kw["body"] = body

        if method in self._encode_url_methods:
            return self.request_encode_url(
                method,
                url,
                fields=fields,  # type: ignore[arg-type]
                headers=headers,
                **urlopen_kw,
            )
        else:
            return self.request_encode_body(  # type: ignore[no-any-return]
                method,
                url,
                fields=fields,
                headers=headers,
                **urlopen_kw,
            )

    @typing.overload
    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...

    @typing.overload
    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the url. This is useful for request methods like GET, HEAD, DELETE, etc.
        """
        if headers is None:
            headers = self.headers

        extra_kw: dict[str, typing.Any] = {"headers": headers}
        extra_kw.update(urlopen_kw)

        if fields:
            # Fields become the query string.
            url += "?" + urlencode(fields)

        return self.urlopen(method, url, **extra_kw)  # type: ignore[no-any-return]

    @typing.overload
    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse: ...

    @typing.overload
    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **urlopen_kw: typing.Any,
    ) -> HTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the body. This is useful for request methods like POST, PUT, PATCH, etc.

        When ``encode_multipart=True`` (default), then
        :func:`urllib3.encode_multipart_formdata` is used to encode
        the payload with the appropriate content type. Otherwise
        :func:`urllib.parse.urlencode` is used with the
        'application/x-www-form-urlencoded' content type.

        Multipart encoding must be used when posting files, and it's reasonably
        safe to use it in other times too. However, it may break request
        signing, such as with OAuth.

        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
        the MIME type is optional. For example::

            fields = {
                'foo': 'bar',
                'fakefile': ('foofile.txt', 'contents of foofile'),
                'realfile': ('barfile.txt', open('realfile').read()),
                'typedfile': ('bazfile.bin', open('bazfile').read(),
                              'image/jpeg'),
                'nonamefile': 'contents of nonamefile field',
            }

        When uploading a file, providing a filename (the first parameter of the
        tuple) is optional but recommended to best mimic behavior of browsers.

        Note that if ``headers`` are supplied, the 'Content-Type' header will
        be overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request. The random boundary
        string can be explicitly set with the ``multipart_boundary`` parameter.

        :raises TypeError: if both ``fields`` and ``body`` are supplied.
        """
        if headers is None:
            headers = self.headers

        # Copy into an HTTPHeaderDict so the caller's mapping is not mutated.
        extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
        body: bytes | str

        if fields:
            if "body" in urlopen_kw:
                raise TypeError(
                    "request got values for both 'fields' and 'body', can only specify one."
                )

            if encode_multipart:
                body, content_type = encode_multipart_formdata(
                    fields, boundary=multipart_boundary
                )
            else:
                body, content_type = (
                    urlencode(fields),  # type: ignore[arg-type]
                    "application/x-www-form-urlencoded",
                )

            extra_kw["body"] = body
            # setdefault: respect an explicit Content-Type already present.
            extra_kw["headers"].setdefault("Content-Type", content_type)

        extra_kw.update(urlopen_kw)

        return self.urlopen(method, url, **extra_kw)  # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class AsyncRequestMethods:
    """
    Convenience mixin for classes who implement a :meth:`urlopen` method, such
    as :class:`urllib3.AsyncHTTPConnectionPool` and
    :class:`urllib3.AsyncPoolManager`.

    Provides behavior for making common types of HTTP request methods and
    decides which type of request field encoding to use.

    Specifically,

    :meth:`.request_encode_url` is for sending requests whose fields are
    encoded in the URL (such as GET, HEAD, DELETE).

    :meth:`.request_encode_body` is for sending requests whose fields are
    encoded in the *body* of the request using multipart or www-form-urlencoded
    (such as for POST, PUT, PATCH).

    :meth:`.request` is for making any kind of request, it will look up the
    appropriate encoding format and use one of the above two methods to make
    the request.

    Initializer parameters:

    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.
    """

    # HTTP verbs whose ``fields`` go into the query string; every other verb
    # gets its fields encoded into the request body.
    _encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}

    def __init__(self, headers: typing.Mapping[str, str] | None = None) -> None:
        # ``None`` (or any falsy mapping) is normalized to a fresh dict so the
        # attribute is always a usable mapping.
        self.headers = headers or {}

    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...

    async def urlopen(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """Abstract low-level entry point; concrete pools/managers must override it."""
        raise NotImplementedError(
            "Classes extending RequestMethods must implement "
            "their own ``urlopen`` method."
        )

    @typing.overload
    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = ...,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        json: typing.Any | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    async def request(
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | _TYPE_ASYNC_BODY | None = None,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        json: typing.Any | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the appropriate encoding of
        ``fields`` based on the ``method`` used.

        This is a convenience method that requires the least amount of manual
        effort. It can be used in most situations, while still having the
        option to drop down to more specific methods when necessary, such as
        :meth:`request_encode_url`, :meth:`request_encode_body`,
        or even the lowest level :meth:`urlopen`.
        """
        method = method.upper()

        # ``body`` and ``json`` both describe the payload; accepting both
        # would be ambiguous.
        if json is not None and body is not None:
            raise TypeError(
                "request got values for both 'body' and 'json' parameters which are mutually exclusive"
            )

        if json is not None:
            if headers is None:
                # Copy so the shared ``self.headers`` mapping is never mutated.
                headers = self.headers.copy()  # type: ignore
            # Only set the content type when the caller did not provide one
            # (case-insensitive lookup).
            if "content-type" not in map(str.lower, headers.keys()):
                headers["Content-Type"] = "application/json"  # type: ignore

            # Compact separators keep the serialized payload minimal.
            body = _json.dumps(json, separators=(",", ":"), ensure_ascii=False).encode(
                "utf-8"
            )

        if body is not None:
            urlopen_kw["body"] = body

        # Dispatch on the verb: query-string encoding vs body encoding.
        if method in self._encode_url_methods:
            return await self.request_encode_url(
                method,
                url,
                fields=fields,  # type: ignore[arg-type]
                headers=headers,
                **urlopen_kw,
            )
        else:
            return await self.request_encode_body(  # type: ignore[no-any-return]
                method,
                url,
                fields=fields,
                headers=headers,
                **urlopen_kw,
            )

    @typing.overload
    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    async def request_encode_url(
        self,
        method: str,
        url: str,
        fields: _TYPE_ENCODE_URL_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the url. This is useful for request methods like GET, HEAD, DELETE, etc.
        """
        if headers is None:
            headers = self.headers

        extra_kw: dict[str, typing.Any] = {"headers": headers}
        extra_kw.update(urlopen_kw)

        # Falsy (empty/None) fields leave the URL untouched.
        if fields:
            url += "?" + urlencode(fields)

        return await self.urlopen(method, url, **extra_kw)  # type: ignore[no-any-return]

    @typing.overload
    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[False] = ...,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = ...,
        headers: typing.Mapping[str, str] | None = ...,
        encode_multipart: bool = ...,
        multipart_boundary: str | None = ...,
        *,
        multiplexed: Literal[True],
        **urlopen_kw: typing.Any,
    ) -> ResponsePromise: ...

    async def request_encode_body(
        self,
        method: str,
        url: str,
        fields: _TYPE_FIELDS | None = None,
        headers: typing.Mapping[str, str] | None = None,
        encode_multipart: bool = True,
        multipart_boundary: str | None = None,
        **urlopen_kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Make a request using :meth:`urlopen` with the ``fields`` encoded in
        the body. This is useful for request methods like POST, PUT, PATCH, etc.

        When ``encode_multipart=True`` (default), then
        :func:`urllib3.encode_multipart_formdata` is used to encode
        the payload with the appropriate content type. Otherwise
        :func:`urllib.parse.urlencode` is used with the
        'application/x-www-form-urlencoded' content type.

        Multipart encoding must be used when posting files, and it's reasonably
        safe to use it in other times too. However, it may break request
        signing, such as with OAuth.

        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data, MIME type) tuple where
        the MIME type is optional. For example::

            fields = {
                'foo': 'bar',
                'fakefile': ('foofile.txt', 'contents of foofile'),
                'realfile': ('barfile.txt', open('realfile').read()),
                'typedfile': ('bazfile.bin', open('bazfile').read(),
                              'image/jpeg'),
                'nonamefile': 'contents of nonamefile field',
            }

        When uploading a file, providing a filename (the first parameter of the
        tuple) is optional but recommended to best mimic behavior of browsers.

        Note that if ``headers`` are supplied, the 'Content-Type' header will
        be overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request. The random boundary
        string can be explicitly set with the ``multipart_boundary`` parameter.
        """
        if headers is None:
            headers = self.headers

        # Wrap in HTTPHeaderDict so ``setdefault`` below is case-insensitive
        # and the caller's mapping is not mutated.
        extra_kw: dict[str, typing.Any] = {"headers": HTTPHeaderDict(headers)}
        body: bytes | str

        if fields:
            if "body" in urlopen_kw:
                raise TypeError(
                    "request got values for both 'fields' and 'body', can only specify one."
                )

            if encode_multipart:
                body, content_type = encode_multipart_formdata(
                    fields, boundary=multipart_boundary
                )
            else:
                body, content_type = (
                    urlencode(fields),  # type: ignore[arg-type]
                    "application/x-www-form-urlencoded",
                )

            extra_kw["body"] = body
            # Respect a caller-provided Content-Type if one is already present.
            extra_kw["headers"].setdefault("Content-Type", content_type)

        extra_kw.update(urlopen_kw)

        return await self.urlopen(method, url, **extra_kw)  # type: ignore[no-any-return]
|
||||
94
.venv/lib/python3.9/site-packages/urllib3_future/_typing.py
Normal file
94
.venv/lib/python3.9/site-packages/urllib3_future/_typing.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations

import typing
from enum import Enum

from .backend import LowLevelResponse
from .backend._async import AsyncLowLevelResponse
from .fields import RequestField
from .util.request import _TYPE_FAILEDTELL
from .util.timeout import _TYPE_DEFAULT, Timeout

if typing.TYPE_CHECKING:
    import ssl

    from typing_extensions import Literal, TypedDict

    # Field names match the decoded dict returned by
    # ``ssl.SSLSocket.getpeercert()``.
    class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False):
        subjectAltName: tuple[tuple[str, str], ...]
        subject: tuple[tuple[tuple[str, str], ...], ...]
        serialNumber: str


# Payload kinds accepted as a request body (sync and async responses included).
_TYPE_BODY: typing.TypeAlias = typing.Union[
    bytes,
    typing.IO[typing.Any],
    typing.Iterable[bytes],
    typing.Iterable[str],
    str,
    LowLevelResponse,
    AsyncLowLevelResponse,
]

# Additional payload kinds accepted by the async request APIs only.
_TYPE_ASYNC_BODY: typing.TypeAlias = typing.Union[
    typing.AsyncIterable[bytes],
    typing.AsyncIterable[str],
]

_TYPE_FIELD_VALUE: typing.TypeAlias = typing.Union[str, bytes]
# A form field value, optionally a (filename, value) or
# (filename, value, mime_type) tuple.
_TYPE_FIELD_VALUE_TUPLE: typing.TypeAlias = typing.Union[
    _TYPE_FIELD_VALUE,
    typing.Tuple[str, _TYPE_FIELD_VALUE],
    typing.Tuple[str, _TYPE_FIELD_VALUE, str],
]

_TYPE_FIELDS_SEQUENCE: typing.TypeAlias = typing.Sequence[
    typing.Union[typing.Tuple[str, _TYPE_FIELD_VALUE_TUPLE], RequestField]
]
_TYPE_FIELDS: typing.TypeAlias = typing.Union[
    _TYPE_FIELDS_SEQUENCE,
    typing.Mapping[str, _TYPE_FIELD_VALUE_TUPLE],
]
# Fields destined for the URL query string (see RequestMethods.request_encode_url).
_TYPE_ENCODE_URL_FIELDS: typing.TypeAlias = typing.Union[
    typing.Sequence[typing.Tuple[str, typing.Union[str, bytes]]],
    typing.Mapping[str, typing.Union[str, bytes]],
]
_TYPE_SOCKET_OPTIONS: typing.TypeAlias = typing.Sequence[
    typing.Union[
        typing.Tuple[int, int, int],
        typing.Tuple[int, int, int, str],
    ]
]
# Shape of a ``__reduce__`` result: (callable, args).
_TYPE_REDUCE_RESULT: typing.TypeAlias = typing.Tuple[
    typing.Callable[..., object], typing.Tuple[object, ...]
]


_TYPE_TIMEOUT: typing.TypeAlias = typing.Union[float, _TYPE_DEFAULT, Timeout, None]
_TYPE_TIMEOUT_INTERNAL: typing.TypeAlias = typing.Union[float, _TYPE_DEFAULT, None]
_TYPE_PEER_CERT_RET: typing.TypeAlias = typing.Union[
    "_TYPE_PEER_CERT_RET_DICT", bytes, None
]

_TYPE_BODY_POSITION: typing.TypeAlias = typing.Union[int, _TYPE_FAILEDTELL]

try:
    from typing import TypedDict

    # Options bundle for SOCKS proxy connections; presumably forwarded to the
    # SOCKS contrib layer — confirm against contrib.socks.
    class _TYPE_SOCKS_OPTIONS(TypedDict):
        socks_version: int | Enum
        proxy_host: str | None
        proxy_port: str | None
        username: str | None
        password: str | None
        rdns: bool

except ImportError:  # Python 3.7
    # Degrade to a plain dict alias when typing.TypedDict is unavailable.
    _TYPE_SOCKS_OPTIONS = typing.Dict[str, typing.Any]  # type: ignore[misc, assignment]


class ProxyConfig(typing.NamedTuple):
    # TLS/verification settings applied to connections made through a proxy.
    ssl_context: ssl.SSLContext | None
    use_forwarding_for_https: bool
    assert_hostname: None | str | Literal[False]
    assert_fingerprint: str | None
|
||||
@@ -0,0 +1,4 @@
|
||||
# This file is protected via CODEOWNERS
from __future__ import annotations

# Canonical version string for the urllib3_future distribution.
__version__ = "2.17.902"
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations

from ._base import (
    BaseBackend,
    ConnectionInfo,
    HttpVersion,
    LowLevelResponse,
    QuicPreemptiveCacheType,
    ResponsePromise,
)
from .hface import HfaceBackend

# Public surface of the backend abstraction layer.
__all__ = (
    "BaseBackend",
    "HfaceBackend",
    "HttpVersion",
    "QuicPreemptiveCacheType",
    "LowLevelResponse",
    "ConnectionInfo",
    "ResponsePromise",
)
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations

from ._base import AsyncBaseBackend, AsyncLowLevelResponse
from .hface import AsyncHfaceBackend

# Public surface of the async backend abstraction layer.
__all__ = (
    "AsyncBaseBackend",
    "AsyncLowLevelResponse",
    "AsyncHfaceBackend",
)
|
||||
@@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..._collections import HTTPHeaderDict
|
||||
from ...contrib.ssa import AsyncSocket, SSLAsyncSocket
|
||||
from .._base import BaseBackend, ResponsePromise
|
||||
from ...util.response import BytesQueueBuffer
|
||||
|
||||
|
||||
class AsyncDirectStreamAccess:
    """Socket-like async facade over a single stream identified by ``stream_id``.

    ``read`` and ``write`` are backend-supplied coroutines; every operation on
    this object delegates to them with the bound stream id. Either side may be
    absent (``None``), producing a read-only or write-only stream.
    """

    def __init__(
        self,
        stream_id: int,
        read: typing.Callable[
            [int | None, int | None, bool, bool],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None = None,
        write: typing.Callable[[bytes, int, bool], typing.Awaitable[None]]
        | None = None,
    ) -> None:
        self._stream_id = stream_id
        self._read = read
        self._write = write

    @property
    def closed(self) -> bool:
        # Closed once both directions have been torn down.
        return self._read is None and self._write is None

    async def readinto(self, b: bytearray) -> int:
        """Read up to ``len(b)`` bytes into ``b``; return the byte count (0 on EOF)."""
        if self._read is None:
            raise OSError("read operation on a closed stream")

        temp = await self.recv(len(b))

        if len(temp) == 0:
            return 0
        else:
            b[: len(temp)] = temp
            return len(temp)

    def readable(self) -> bool:
        return self._read is not None

    def writable(self) -> bool:
        return self._write is not None

    def seekable(self) -> bool:
        # Streams are forward-only.
        return False

    def fileno(self) -> int:
        # No OS-level file descriptor backs this stream.
        return -1

    def name(self) -> int:
        return -1

    async def recv(self, __bufsize: int, __flags: int = 0) -> bytes:
        # ``__flags`` is accepted for socket API compatibility and ignored.
        data, _, _ = await self.recv_extended(__bufsize)
        return data

    async def recv_extended(
        self, __bufsize: int | None
    ) -> tuple[bytes, bool, HTTPHeaderDict | None]:
        """Receive data plus end-of-transmission flag and optional trailers."""
        if self._read is None:
            raise OSError("stream closed error")

        # Third argument tells the backend whether a specific amount was
        # requested (``__bufsize is not None``).
        data, eot, trailers = await self._read(
            __bufsize,
            self._stream_id,
            __bufsize is not None,
            False,
        )

        # End of transmission: the read side is exhausted for good.
        if eot:
            self._read = None

        return data, eot, trailers

    async def sendall(self, __data: bytes, __flags: int = 0) -> None:
        if self._write is None:
            raise OSError("stream write not permitted")

        await self._write(__data, self._stream_id, False)

    async def write(self, __data: bytes) -> int:
        if self._write is None:
            raise OSError("stream write not permitted")

        await self._write(__data, self._stream_id, False)

        return len(__data)

    async def sendall_extended(
        self, __data: bytes, __close_stream: bool = False
    ) -> None:
        """Send ``__data``, optionally closing the write side afterward."""
        if self._write is None:
            raise OSError("stream write not permitted")

        await self._write(__data, self._stream_id, __close_stream)

    async def close(self) -> None:
        """Tear down both directions of the stream."""
        if self._write is not None:
            # Empty payload with the end flag signals write-side completion.
            await self._write(b"", self._stream_id, True)
            self._write = None
        if self._read is not None:
            # Final read call with the last flag set — presumably tells the
            # backend to discard the read side; confirm against the backend.
            await self._read(None, self._stream_id, False, True)
            self._read = None
|
||||
|
||||
|
||||
class AsyncLowLevelResponse:
    """Implemented for backward compatibility purposes. It is there to impose http.client like
    basic response object. So that we don't have to change urllib3 tested behaviors."""

    # Backend-supplied coroutine yielding (data, end_of_transmission, trailers);
    # set to None once the body is fully consumed or the response is closed.
    __internal_read_st: (
        typing.Callable[
            [int | None, int | None],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None
    )

    def __init__(
        self,
        method: str,
        status: int,
        version: int,
        reason: str,
        headers: HTTPHeaderDict,
        body: typing.Callable[
            [int | None, int | None],
            typing.Awaitable[tuple[bytes, bool, HTTPHeaderDict | None]],
        ]
        | None,
        *,
        authority: str | None = None,
        port: int | None = None,
        stream_id: int | None = None,
        # this obj should not be always available[...]
        dsa: AsyncDirectStreamAccess | None = None,
        stream_abort: typing.Callable[[int], typing.Awaitable[None]] | None = None,
    ) -> None:
        self.status = status
        self.version = version
        self.reason = reason
        self.msg = headers
        self._method = method

        self.__internal_read_st = body

        has_body = self.__internal_read_st is not None

        # Without a body callable there is nothing to stream: born closed.
        self.closed = has_body is False
        self._eot = self.closed

        # is kept to determine if we can upgrade conn
        self.authority = authority
        self.port = port

        # http.client compat layer
        self.debuglevel: int = 0  # no-op flag, kept for strict backward compatibility!
        self.chunked: bool = (  # is "chunked" being used? http1 only!
            self.version == 11 and "chunked" == self.msg.get("transfer-encoding")
        )
        self.chunk_left: int | None = None  # bytes left to read in current chunk
        self.length: int | None = None  # number of bytes left in response
        self.will_close: bool = (
            False  # no-op flag, kept for strict backward compatibility!
        )

        if not self.chunked:
            content_length = self.msg.get("content-length")
            self.length = int(content_length) if content_length else None

        #: not part of http.client but useful to track (raw) download speeds!
        self.data_in_count = 0

        self._stream_id = stream_id

        # Bytes fetched beyond what the caller asked for, served on later reads.
        self.__buffer_excess: BytesQueueBuffer = BytesQueueBuffer()
        self.__promise: ResponsePromise | None = None
        self._dsa = dsa
        self._stream_abort = stream_abort

        # Trailer headers (if any) surface here once the stream reaches EOT.
        self.trailers: HTTPHeaderDict | None = None

    @property
    def fp(self) -> typing.NoReturn:
        # http.client exposed ``fp``; accessing it here is an explicit error.
        raise RuntimeError(
            "urllib3-future no longer expose a filepointer-like in responses. It was a remnant from the http.client era. "
            "We no longer support it."
        )

    @property
    def from_promise(self) -> ResponsePromise | None:
        return self.__promise

    @from_promise.setter
    def from_promise(self, value: ResponsePromise) -> None:
        # A promise may only be attached to the response of its own stream.
        if value.stream_id != self._stream_id:
            raise ValueError(
                "Trying to assign a ResponsePromise to an unrelated LowLevelResponse"
            )
        self.__promise = value

    @property
    def method(self) -> str:
        """Original HTTP verb used in the request."""
        return self._method

    def isclosed(self) -> bool:
        """Here we do not create a fp sock like http.client Response."""
        return self.closed

    async def read(self, __size: int | None = None) -> bytes:
        """Read up to ``__size`` bytes of body (everything left when ``None``)."""
        if self.closed is True or self.__internal_read_st is None:
            # overly protective, just in case.
            raise ValueError(
                "I/O operation on closed file."
            )  # Defensive: Should not be reachable in normal condition

        if __size == 0:
            return b""  # Defensive: This is unreachable, this case is already covered higher in the stack.

        # Serve straight from the leftover buffer when it already holds enough.
        buf_capacity = len(self.__buffer_excess)
        data_ready_to_go = (
            __size is not None and buf_capacity > 0 and buf_capacity >= __size
        )

        if self._eot is False and not data_ready_to_go:
            data, self._eot, self.trailers = await self.__internal_read_st(
                __size, self._stream_id
            )

            # Anything past what the caller asked for stays buffered.
            self.__buffer_excess.put(data)
            buf_capacity = len(self.__buffer_excess)

        data = self.__buffer_excess.get(
            __size if __size is not None and __size > 0 else buf_capacity
        )

        size_in = len(data)

        buf_capacity -= size_in

        # Stream fully drained: release backend references and mark closed.
        if self._eot and buf_capacity == 0:
            self._stream_abort = None
            self.closed = True
            # NOTE(review): ``_sock`` is never set in __init__ on this async
            # variant; this assignment only creates the attribute — confirm intent.
            self._sock = None

        if self.chunked:
            self.chunk_left = buf_capacity if buf_capacity else None
        elif self.length is not None:
            self.length -= size_in

        self.data_in_count += size_in

        return data

    async def abort(self) -> None:
        """Cancel the underlying stream if the body was not fully consumed."""
        if self._stream_abort is not None:
            if self._eot is False:
                if self._stream_id is not None:
                    await self._stream_abort(self._stream_id)
                    self._eot = True
        self._stream_abort = None
        self.closed = True
        self._dsa = None

    def close(self) -> None:
        """Drop the body callable and mark the response closed (idempotent)."""
        self.__internal_read_st = None
        self.closed = True
        self._dsa = None
|
||||
|
||||
|
||||
class AsyncBaseBackend(BaseBackend):
    """Async counterpart of :class:`BaseBackend`: same contract, awaitable methods.

    Every override below intentionally raises ``NotImplementedError``; concrete
    backends (e.g. the hface implementation) supply the behavior.
    """

    # The async backend owns an async socket; ignore[assignment] silences the
    # type narrowing versus the sync base class attribute.
    sock: AsyncSocket | SSLAsyncSocket | None  # type: ignore[assignment]

    async def _upgrade(self) -> None:  # type: ignore[override]
        """Upgrade conn from svn ver to max supported."""
        raise NotImplementedError

    async def _tunnel(self) -> None:  # type: ignore[override]
        """Emit proper CONNECT request to the http (server) intermediary."""
        raise NotImplementedError

    async def _new_conn(self) -> AsyncSocket | None:  # type: ignore[override]
        """Run protocol initialization from there. Return None to ensure that the child
        class correctly create the socket / connection."""
        raise NotImplementedError

    async def _post_conn(self) -> None:  # type: ignore[override]
        """Should be called after _new_conn proceed as expected.
        Expect protocol handshake to be done here."""
        raise NotImplementedError

    async def endheaders(  # type: ignore[override]
        self,
        message_body: bytes | None = None,
        *,
        encode_chunked: bool = False,
        expect_body_afterward: bool = False,
    ) -> ResponsePromise | None:
        """This method conclude the request context construction."""
        raise NotImplementedError

    async def getresponse(  # type: ignore[override]
        self, *, promise: ResponsePromise | None = None
    ) -> AsyncLowLevelResponse:
        """Fetch the HTTP response. You SHOULD not retrieve the body in that method, it SHOULD be done
        in the LowLevelResponse, so it enable stream capabilities and remain efficient.
        """
        raise NotImplementedError

    async def close(self) -> None:  # type: ignore[override]
        """End the connection, do some reinit, closing of fd, etc..."""
        raise NotImplementedError

    async def send(  # type: ignore[override]
        self,
        data: (bytes | typing.IO[typing.Any] | typing.Iterable[bytes] | str),
        *,
        eot: bool = False,
    ) -> ResponsePromise | None:
        """The send() method SHOULD be invoked after calling endheaders() if and only if the request
        context specify explicitly that a body is going to be sent."""
        raise NotImplementedError

    async def ping(self) -> None:  # type: ignore[override]
        # Liveness probe — presumably a protocol-level PING frame; confirm in
        # the concrete backend implementation.
        raise NotImplementedError
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,690 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import socket
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
from base64 import b64encode
|
||||
from datetime import datetime, timedelta
|
||||
from secrets import token_bytes
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ssl import SSLSocket, SSLContext, TLSVersion
|
||||
from .._typing import _TYPE_SOCKET_OPTIONS
|
||||
from ._async import AsyncLowLevelResponse
|
||||
|
||||
from .._collections import HTTPHeaderDict
|
||||
from .._constant import DEFAULT_BLOCKSIZE, DEFAULT_KEEPALIVE_DELAY
|
||||
from ..util.response import BytesQueueBuffer
|
||||
|
||||
|
||||
class HttpVersion(str, enum.Enum):
    """Describe possible SVN protocols that can be supported."""

    # The str mixin lets members be used anywhere a plain version string is expected.
    h11 = "HTTP/1.1"
    # we know that it is rather "HTTP/2" than "HTTP/2.0"
    # it is this way to remain somewhat compatible with http.client
    # http_svn (int). 9 -> 11 -> 20 -> 30
    h2 = "HTTP/2.0"
    h3 = "HTTP/3.0"
|
||||
|
||||
|
||||
class ConnectionInfo:
    """Mutable record of per-connection metrics and TLS details.

    Every attribute starts as ``None`` and is meant to be filled in as the
    corresponding information becomes available. Proxy hops are excluded:
    all fields describe the remote peer itself.
    """

    def __init__(self) -> None:
        #: Time taken to establish the connection
        self.established_latency: timedelta | None = None

        #: HTTP protocol used with the remote peer (not the proxy)
        self.http_version: HttpVersion | None = None

        #: The SSL certificate presented by the remote peer (not the proxy)
        self.certificate_der: bytes | None = None
        self.certificate_dict: (
            dict[str, int | tuple[tuple[str, str], ...] | tuple[str, ...] | str] | None
        ) = None

        #: The SSL issuer certificate for the remote peer certificate (not the proxy)
        self.issuer_certificate_der: bytes | None = None
        self.issuer_certificate_dict: (
            dict[str, int | tuple[tuple[str, str], ...] | tuple[str, ...] | str] | None
        ) = None

        #: The IP address used to reach the remote peer (not the proxy), that was yield by your resolver.
        self.destination_address: tuple[str, int] | None = None

        #: The TLS cipher used to secure the exchanges (not the proxy)
        self.cipher: str | None = None
        #: The TLS revision used (not the proxy)
        self.tls_version: TLSVersion | None = None
        #: The time taken to reach a complete TLS liaison between the remote peer and us. (not the proxy)
        self.tls_handshake_latency: timedelta | None = None
        #: Time taken to resolve a domain name into a reachable IP address.
        self.resolution_latency: timedelta | None = None

        #: Time taken to encode and send the whole request through the socket.
        self.request_sent_latency: timedelta | None = None

    def __repr__(self) -> str:
        # Field names are listed explicitly so the dumped dict keeps its
        # historical key ordering regardless of attribute-assignment order.
        traced_fields = (
            "established_latency",
            "certificate_der",
            "certificate_dict",
            "issuer_certificate_der",
            "issuer_certificate_dict",
            "destination_address",
            "cipher",
            "tls_version",
            "tls_handshake_latency",
            "http_version",
            "resolution_latency",
            "request_sent_latency",
        )
        return str({field: getattr(self, field) for field in traced_fields})

    def is_encrypted(self) -> bool:
        """Whether a TLS certificate was captured for the remote peer."""
        return self.certificate_der is not None
|
||||
|
||||
|
||||
class DirectStreamAccess:
    """Socket-like synchronous facade over a single stream identified by ``stream_id``.

    ``read``/``write`` are backend callables; they are wrapped in closures that
    pre-bind the stream id so the rest of the class only deals with
    (amount, final-flag) / (buffer, eot-flag) signatures. Either side may be
    absent (``None``), producing a read-only or write-only stream.
    """

    def __init__(
        self,
        stream_id: int,
        read: typing.Callable[
            [int | None, int | None, bool, bool],
            tuple[bytes, bool, HTTPHeaderDict | None],
        ]
        | None = None,
        write: typing.Callable[[bytes, int, bool], None] | None = None,
    ) -> None:
        self._stream_id = stream_id

        if read is not None:
            # Bind the stream id; the third backend argument records whether a
            # specific byte amount was requested (``amt is not None``).
            self._read: (
                typing.Callable[
                    [int | None, bool], tuple[bytes, bool, HTTPHeaderDict | None]
                ]
                | None
            ) = lambda amt, fo: read(amt, self._stream_id, amt is not None, fo)
        else:
            self._read = None

        if write is not None:
            self._write: typing.Callable[[bytes, bool], None] | None = (
                lambda buf, eot: write(buf, self._stream_id, eot)
            )
        else:
            self._write = None

    @property
    def closed(self) -> bool:
        # Closed once both directions have been torn down.
        return self._read is None and self._write is None

    def readinto(self, b: bytearray) -> int:
        """Read up to ``len(b)`` bytes into ``b``; return the byte count (0 on EOF)."""
        if self._read is None:
            raise OSError("read operation on a closed stream")

        temp = self.recv(len(b))

        if len(temp) == 0:
            return 0
        else:
            b[: len(temp)] = temp
            return len(temp)

    def readable(self) -> bool:
        return self._read is not None

    def writable(self) -> bool:
        return self._write is not None

    def seekable(self) -> bool:
        # Streams are forward-only.
        return False

    def fileno(self) -> int:
        # No OS-level file descriptor backs this stream.
        return -1

    def name(self) -> int:
        return -1

    def recv(self, __bufsize: int, __flags: int = 0) -> bytes:
        # ``__flags`` is accepted for socket API compatibility and ignored.
        data, _, _ = self.recv_extended(__bufsize)
        return data

    def recv_extended(
        self, __bufsize: int | None
    ) -> tuple[bytes, bool, HTTPHeaderDict | None]:
        """Receive data plus end-of-transmission flag and optional trailers."""
        if self._read is None:
            raise OSError("stream closed error")

        data, eot, trailers = self._read(__bufsize, False)

        # End of transmission: the read side is exhausted for good.
        if eot:
            self._read = None

        return data, eot, trailers

    def sendall(self, __data: bytes, __flags: int = 0) -> None:
        if self._write is None:
            raise OSError("stream write not permitted")

        self._write(__data, False)

    def write(self, __data: bytes) -> int:
        if self._write is None:
            raise OSError("stream write not permitted")

        self._write(__data, False)

        return len(__data)

    def sendall_extended(self, __data: bytes, __close_stream: bool = False) -> None:
        """Send ``__data``, optionally closing the write side afterward."""
        if self._write is None:
            raise OSError("stream write not permitted")

        self._write(__data, __close_stream)

    def close(self) -> None:
        """Tear down both directions of the stream."""
        if self._write is not None:
            # Empty payload with the end flag signals write-side completion.
            self._write(b"", True)
            self._write = None
        if self._read is not None:
            # Final read with the last flag set — presumably tells the backend
            # to discard the read side; confirm against the backend.
            self._read(None, True)
            self._read = None
|
||||
|
||||
|
||||
class LowLevelResponse:
|
||||
"""Implemented for backward compatibility purposes. It is there to impose http.client like
|
||||
basic response object. So that we don't have to change urllib3 tested behaviors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str,
|
||||
status: int,
|
||||
version: int,
|
||||
reason: str,
|
||||
headers: HTTPHeaderDict,
|
||||
body: typing.Callable[
|
||||
[int | None, int | None], tuple[bytes, bool, HTTPHeaderDict | None]
|
||||
]
|
||||
| None,
|
||||
*,
|
||||
authority: str | None = None,
|
||||
port: int | None = None,
|
||||
stream_id: int | None = None,
|
||||
sock: socket.socket | None = None,
|
||||
# this obj should not be always available[...]
|
||||
dsa: DirectStreamAccess | None = None,
|
||||
stream_abort: typing.Callable[[int], None] | None = None,
|
||||
):
|
||||
self.status = status
|
||||
self.version = version
|
||||
self.reason = reason
|
||||
self.msg = headers
|
||||
self._method = method
|
||||
|
||||
self.__internal_read_st = body
|
||||
|
||||
has_body = self.__internal_read_st is not None
|
||||
|
||||
self.closed = has_body is False
|
||||
self._eot = self.closed
|
||||
|
||||
# is kept to determine if we can upgrade conn
|
||||
self.authority = authority
|
||||
self.port = port
|
||||
|
||||
# http.client additional compat layer
|
||||
# although rarely used, some 3rd party library may
|
||||
# peek at those for whatever reason. most of the time they
|
||||
# are wrong to do so.
|
||||
self.debuglevel: int = 0 # no-op flag, kept for strict backward compatibility!
|
||||
self.chunked: bool = ( # is "chunked" being used? http1 only!
|
||||
self.version == 11 and "chunked" == self.msg.get("transfer-encoding")
|
||||
)
|
||||
self.chunk_left: int | None = None # bytes left to read in current chunk
|
||||
self.length: int | None = None # number of bytes left in response
|
||||
self.will_close: bool = (
|
||||
False # no-op flag, kept for strict backward compatibility!
|
||||
)
|
||||
|
||||
if not self.chunked:
|
||||
content_length = self.msg.get("content-length")
|
||||
self.length = int(content_length) if content_length else None
|
||||
|
||||
#: not part of http.client but useful to track (raw) download speeds!
|
||||
self.data_in_count = 0
|
||||
|
||||
# tricky part...
|
||||
# sometime 3rd party library tend to access hazardous materials...
|
||||
# they want a direct socket access.
|
||||
self._sock = sock
|
||||
self._fp: socket.SocketIO | None = None
|
||||
self._dsa = dsa
|
||||
self._stream_abort = stream_abort
|
||||
|
||||
self._stream_id = stream_id
|
||||
|
||||
self.__buffer_excess: BytesQueueBuffer = BytesQueueBuffer()
|
||||
self.__promise: ResponsePromise | None = None
|
||||
|
||||
self.trailers: HTTPHeaderDict | None = None
|
||||
|
||||
@property
|
||||
def fp(self) -> socket.SocketIO | DirectStreamAccess | None:
|
||||
warnings.warn(
|
||||
(
|
||||
"This is a rather awkward situation. A program (probably) tried to access the socket object "
|
||||
"directly, thus bypassing our state-machine protocol (amongst other things). "
|
||||
"This is currently unsupported and dangerous. Errors will occurs if you negotiated HTTP/2 or later versions. "
|
||||
"We tried to be rather strict on the backward compatibility between urllib3 and urllib3-future, "
|
||||
"but this is rather complicated to support (e.g. direct socket access). "
|
||||
"You are probably better off using our higher level read() function. "
|
||||
"Please open an issue at https://github.com/jawah/urllib3.future/issues to gain support or "
|
||||
"insights on it."
|
||||
),
|
||||
DeprecationWarning,
|
||||
2,
|
||||
)
|
||||
|
||||
if self._sock is None:
|
||||
if self.status == 101 or (
|
||||
self._method == "CONNECT" and 200 <= self.status < 300
|
||||
):
|
||||
return self._dsa
|
||||
|
||||
# well, there's nothing we can do more :'(
|
||||
raise AttributeError
|
||||
|
||||
if self._fp is None:
|
||||
self._fp = self._sock.makefile("rb") # type: ignore[assignment]
|
||||
|
||||
return self._fp
|
||||
|
||||
@property
|
||||
def from_promise(self) -> ResponsePromise | None:
|
||||
return self.__promise
|
||||
|
||||
@from_promise.setter
|
||||
def from_promise(self, value: ResponsePromise) -> None:
|
||||
if value.stream_id != self._stream_id:
|
||||
raise ValueError(
|
||||
"Trying to assign a ResponsePromise to an unrelated LowLevelResponse"
|
||||
)
|
||||
self.__promise = value
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
"""Original HTTP verb used in the request."""
|
||||
return self._method
|
||||
|
||||
def isclosed(self) -> bool:
|
||||
"""Here we do not create a fp sock like http.client Response."""
|
||||
return self.closed
|
||||
|
||||
def read(self, __size: int | None = None) -> bytes:
|
||||
if self.closed is True or self.__internal_read_st is None:
|
||||
# overly protective, just in case.
|
||||
raise ValueError(
|
||||
"I/O operation on closed file."
|
||||
) # Defensive: Should not be reachable in normal condition
|
||||
|
||||
if __size == 0:
|
||||
return b"" # Defensive: This is unreachable, this case is already covered higher in the stack.
|
||||
|
||||
buf_capacity = len(self.__buffer_excess)
|
||||
data_ready_to_go = (
|
||||
__size is not None and buf_capacity > 0 and buf_capacity >= __size
|
||||
)
|
||||
|
||||
if self._eot is False and not data_ready_to_go:
|
||||
data, self._eot, self.trailers = self.__internal_read_st(
|
||||
__size, self._stream_id
|
||||
)
|
||||
|
||||
self.__buffer_excess.put(data)
|
||||
buf_capacity = len(self.__buffer_excess)
|
||||
|
||||
data = self.__buffer_excess.get(
|
||||
__size if __size is not None and __size > 0 else buf_capacity
|
||||
)
|
||||
|
||||
size_in = len(data)
|
||||
|
||||
buf_capacity -= size_in
|
||||
|
||||
if self._eot and buf_capacity == 0:
|
||||
self._stream_abort = None
|
||||
self.closed = True
|
||||
self._sock = None
|
||||
|
||||
if self.chunked:
|
||||
self.chunk_left = buf_capacity if buf_capacity else None
|
||||
elif self.length is not None:
|
||||
self.length -= size_in
|
||||
|
||||
self.data_in_count += size_in
|
||||
|
||||
return data
|
||||
|
||||
def abort(self) -> None:
|
||||
if self._stream_abort is not None:
|
||||
if self._eot is False:
|
||||
if self._stream_id is not None:
|
||||
self._stream_abort(self._stream_id)
|
||||
self._eot = True
|
||||
self._stream_abort = None
|
||||
self.closed = True
|
||||
self._dsa = None
|
||||
|
||||
def close(self) -> None:
|
||||
self.__internal_read_st = None
|
||||
self.closed = True
|
||||
self._sock = None
|
||||
self._dsa = None
|
||||
|
||||
|
||||
class ResponsePromise:
|
||||
def __init__(
|
||||
self,
|
||||
conn: BaseBackend,
|
||||
stream_id: int,
|
||||
request_headers: list[tuple[bytes, bytes]],
|
||||
**parameters: typing.Any,
|
||||
) -> None:
|
||||
self._uid: str = b64encode(token_bytes(16)).decode("ascii")
|
||||
self._conn: BaseBackend = conn
|
||||
self._stream_id: int = stream_id
|
||||
self._response: LowLevelResponse | AsyncLowLevelResponse | None = None
|
||||
self._request_headers = request_headers
|
||||
self._parameters: typing.MutableMapping[str, typing.Any] = parameters
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, ResponsePromise):
|
||||
return False
|
||||
return self.uid == other.uid
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<ResponsePromise '{self.uid}' {self._conn._http_vsn_str} Stream[{self.stream_id}]>"
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
return self._uid
|
||||
|
||||
@property
|
||||
def request_headers(self) -> list[tuple[bytes, bytes]]:
|
||||
return self._request_headers
|
||||
|
||||
@property
|
||||
def stream_id(self) -> int:
|
||||
return self._stream_id
|
||||
|
||||
@property
|
||||
def is_ready(self) -> bool:
|
||||
return self._response is not None
|
||||
|
||||
@property
|
||||
def response(self) -> LowLevelResponse | AsyncLowLevelResponse:
|
||||
if not self._response:
|
||||
raise OSError
|
||||
return self._response
|
||||
|
||||
@response.setter
|
||||
def response(self, value: LowLevelResponse | AsyncLowLevelResponse) -> None:
|
||||
self._response = value
|
||||
|
||||
def set_parameter(self, key: str, value: typing.Any) -> None:
|
||||
self._parameters[key] = value
|
||||
|
||||
def get_parameter(self, key: str) -> typing.Any | None:
|
||||
return self._parameters[key] if key in self._parameters else None
|
||||
|
||||
def update_parameters(self, data: dict[str, typing.Any]) -> None:
|
||||
self._parameters.update(data)
|
||||
|
||||
|
||||
_HostPortType: typing.TypeAlias = typing.Tuple[str, int]
|
||||
QuicPreemptiveCacheType: typing.TypeAlias = typing.MutableMapping[
|
||||
_HostPortType, typing.Optional[_HostPortType]
|
||||
]
|
||||
|
||||
|
||||
class BaseBackend:
|
||||
"""
|
||||
The goal here is to detach ourselves from the http.client package.
|
||||
At first, we'll strictly follow the methods in http.client.HTTPConnection. So that
|
||||
we would be able to implement other backend without disrupting the actual code base.
|
||||
Extend that base class in order to ship another backend with urllib3.
|
||||
"""
|
||||
|
||||
supported_svn: typing.ClassVar[list[HttpVersion] | None] = None
|
||||
scheme: typing.ClassVar[str]
|
||||
|
||||
default_socket_kind: socket.SocketKind = socket.SOCK_STREAM
|
||||
#: Disable Nagle's algorithm by default.
|
||||
default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS] = [
|
||||
(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp")
|
||||
]
|
||||
|
||||
#: Whether this connection verifies the host's certificate.
|
||||
is_verified: bool = False
|
||||
|
||||
#: Whether this proxy connection verified the proxy host's certificate.
|
||||
# If no proxy is currently connected to the value will be ``None``.
|
||||
proxy_is_verified: bool | None = None
|
||||
|
||||
response_class = LowLevelResponse
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int | None = None,
|
||||
timeout: int | float | None = -1,
|
||||
source_address: tuple[str, int] | None = None,
|
||||
blocksize: int = DEFAULT_BLOCKSIZE,
|
||||
*,
|
||||
socket_options: _TYPE_SOCKET_OPTIONS | None = default_socket_options,
|
||||
disabled_svn: set[HttpVersion] | None = None,
|
||||
preemptive_quic_cache: QuicPreemptiveCacheType | None = None,
|
||||
keepalive_delay: float | int | None = DEFAULT_KEEPALIVE_DELAY,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout = timeout
|
||||
self.source_address = source_address
|
||||
self.blocksize = blocksize
|
||||
self.socket_kind = BaseBackend.default_socket_kind
|
||||
self.socket_options = socket_options
|
||||
self.sock: socket.socket | SSLSocket | None = None
|
||||
|
||||
self._response: LowLevelResponse | AsyncLowLevelResponse | None = None
|
||||
# Set it as default
|
||||
self._svn: HttpVersion | None = HttpVersion.h11
|
||||
|
||||
self._tunnel_host: str | None = None
|
||||
self._tunnel_port: int | None = None
|
||||
self._tunnel_scheme: str | None = None
|
||||
self._tunnel_headers: typing.Mapping[str, str] = dict()
|
||||
|
||||
self._disabled_svn = disabled_svn if disabled_svn is not None else set()
|
||||
self._preemptive_quic_cache = preemptive_quic_cache
|
||||
|
||||
if self._disabled_svn:
|
||||
if len(self._disabled_svn) == len(list(HttpVersion)):
|
||||
raise RuntimeError(
|
||||
"You disabled every supported protocols. The HTTP connection object is left with no outcomes."
|
||||
)
|
||||
|
||||
# valuable intel
|
||||
self.conn_info: ConnectionInfo | None = None
|
||||
|
||||
self._promises: dict[str, ResponsePromise] = {}
|
||||
self._promises_per_stream: dict[int, ResponsePromise] = {}
|
||||
self._pending_responses: dict[
|
||||
int, LowLevelResponse | AsyncLowLevelResponse
|
||||
] = {}
|
||||
|
||||
self._start_last_request: datetime | None = None
|
||||
|
||||
self._cached_http_vsn: int | None = None
|
||||
|
||||
self._keepalive_delay: float | None = (
|
||||
keepalive_delay # just forwarded for qh3 idle_timeout conf.
|
||||
)
|
||||
self._connected_at: float | None = None
|
||||
self._last_used_at: float = time.monotonic()
|
||||
|
||||
self._recv_size_ema: float = 0.0
|
||||
|
||||
def __contains__(self, item: ResponsePromise) -> bool:
|
||||
return item.uid in self._promises
|
||||
|
||||
@property
|
||||
def _fast_recv_mode(self) -> bool:
|
||||
if len(self._promises) <= 1 or self._svn is HttpVersion.h3:
|
||||
return True
|
||||
return self._recv_size_ema >= 1450
|
||||
|
||||
@property
|
||||
def last_used_at(self) -> float:
|
||||
return self._last_used_at
|
||||
|
||||
@property
|
||||
def connected_at(self) -> float | None:
|
||||
return self._connected_at
|
||||
|
||||
@property
|
||||
def disabled_svn(self) -> set[HttpVersion]:
|
||||
return self._disabled_svn
|
||||
|
||||
@property
|
||||
def _http_vsn_str(self) -> str:
|
||||
"""Reimplemented for backward compatibility purposes."""
|
||||
assert self._svn is not None
|
||||
return self._svn.value
|
||||
|
||||
@property
|
||||
def _http_vsn(self) -> int:
|
||||
"""Reimplemented for backward compatibility purposes."""
|
||||
assert self._svn is not None
|
||||
if self._cached_http_vsn is None:
|
||||
self._cached_http_vsn = int(self._svn.value.split("/")[-1].replace(".", ""))
|
||||
return self._cached_http_vsn
|
||||
|
||||
@property
|
||||
def is_saturated(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
return not self._promises and not self._pending_responses
|
||||
|
||||
@property
|
||||
def max_stream_count(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def is_multiplexed(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def max_frame_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def _upgrade(self) -> None:
|
||||
"""Upgrade conn from svn ver to max supported."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _tunnel(self) -> None:
|
||||
"""Emit proper CONNECT request to the http (server) intermediary."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _new_conn(self) -> socket.socket | None:
|
||||
"""Run protocol initialization from there. Return None to ensure that the child
|
||||
class correctly create the socket / connection."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _post_conn(self) -> None:
|
||||
"""Should be called after _new_conn proceed as expected.
|
||||
Expect protocol handshake to be done here."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _custom_tls(
|
||||
self,
|
||||
ssl_context: SSLContext | None = None,
|
||||
ca_certs: str | None = None,
|
||||
ca_cert_dir: str | None = None,
|
||||
ca_cert_data: None | str | bytes = None,
|
||||
ssl_minimum_version: int | None = None,
|
||||
ssl_maximum_version: int | None = None,
|
||||
cert_file: str | None = None,
|
||||
key_file: str | None = None,
|
||||
key_password: str | None = None,
|
||||
) -> bool:
|
||||
"""This method serve as bypassing any default tls setup.
|
||||
It is most useful when the encryption does not lie on the TCP layer. This method
|
||||
WILL raise NotImplementedError if the connection is not concerned."""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_tunnel(
|
||||
self,
|
||||
host: str,
|
||||
port: int | None = None,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
scheme: str = "http",
|
||||
) -> None:
|
||||
"""Prepare the connection to set up a tunnel. Does NOT actually do the socket and http connect.
|
||||
Here host:port represent the target (final) server and not the intermediary."""
|
||||
raise NotImplementedError
|
||||
|
||||
def putrequest(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
skip_host: bool = False,
|
||||
skip_accept_encoding: bool = False,
|
||||
) -> None:
|
||||
"""It is the first method called, setting up the request initial context."""
|
||||
raise NotImplementedError
|
||||
|
||||
def putheader(self, header: str, *values: str) -> None:
|
||||
"""For a single header name, assign one or multiple value. This method is called right after putrequest()
|
||||
for each entries."""
|
||||
raise NotImplementedError
|
||||
|
||||
def endheaders(
|
||||
self,
|
||||
message_body: bytes | None = None,
|
||||
*,
|
||||
encode_chunked: bool = False,
|
||||
expect_body_afterward: bool = False,
|
||||
) -> ResponsePromise | None:
|
||||
"""This method conclude the request context construction."""
|
||||
raise NotImplementedError
|
||||
|
||||
def getresponse(
|
||||
self, *, promise: ResponsePromise | None = None
|
||||
) -> LowLevelResponse:
|
||||
"""Fetch the HTTP response. You SHOULD not retrieve the body in that method, it SHOULD be done
|
||||
in the LowLevelResponse, so it enable stream capabilities and remain efficient.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""End the connection, do some reinit, closing of fd, etc..."""
|
||||
raise NotImplementedError
|
||||
|
||||
def send(
|
||||
self,
|
||||
data: bytes | bytearray,
|
||||
*,
|
||||
eot: bool = False,
|
||||
) -> ResponsePromise | None:
|
||||
"""The send() method SHOULD be invoked after calling endheaders() if and only if the request
|
||||
context specify explicitly that a body is going to be sent."""
|
||||
raise NotImplementedError
|
||||
|
||||
def ping(self) -> None:
|
||||
"""Send a PING to the remote peer."""
|
||||
raise NotImplementedError
|
||||
1955
.venv/lib/python3.9/site-packages/urllib3_future/backend/hface.py
Normal file
1955
.venv/lib/python3.9/site-packages/urllib3_future/backend/hface.py
Normal file
File diff suppressed because it is too large
Load Diff
1130
.venv/lib/python3.9/site-packages/urllib3_future/connection.py
Normal file
1130
.venv/lib/python3.9/site-packages/urllib3_future/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
2409
.venv/lib/python3.9/site-packages/urllib3_future/connectionpool.py
Normal file
2409
.venv/lib/python3.9/site-packages/urllib3_future/connectionpool.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
This module contains provisional support for SOCKS proxies from within
|
||||
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
|
||||
SOCKS5. To enable its functionality, either install PySocks or install this
|
||||
module with the ``socks`` extra.
|
||||
|
||||
The SOCKS implementation supports the full range of urllib3 features. It also
|
||||
supports the following SOCKS features:
|
||||
|
||||
- SOCKS4A (``proxy_url='socks4a://...``)
|
||||
- SOCKS4 (``proxy_url='socks4://...``)
|
||||
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
|
||||
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
|
||||
- Usernames and passwords for the SOCKS proxy
|
||||
|
||||
.. note::
|
||||
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
|
||||
your ``proxy_url`` to ensure that DNS resolution is done from the remote
|
||||
server instead of client-side when connecting to a domain name.
|
||||
|
||||
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
|
||||
supports IPv4, IPv6, and domain names.
|
||||
|
||||
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
|
||||
will be sent as the ``userid`` section of the SOCKS request:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
proxy_url="socks4a://<userid>@proxy-host"
|
||||
|
||||
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
|
||||
of the ``proxy_url`` will be sent as the username/password to authenticate
|
||||
with the proxy:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
proxy_url="socks5h://<username>:<password>@proxy-host"
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
import socks
|
||||
except ImportError:
|
||||
import warnings
|
||||
|
||||
from ..exceptions import DependencyWarning
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"SOCKS support in urllib3 requires the installation of optional "
|
||||
"dependencies: specifically, PySocks. For more information, see "
|
||||
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
|
||||
),
|
||||
DependencyWarning,
|
||||
)
|
||||
raise
|
||||
|
||||
import typing
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
from .._typing import _TYPE_SOCKS_OPTIONS
|
||||
from ..backend import HttpVersion
|
||||
from ..connection import HTTPConnection, HTTPSConnection
|
||||
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
|
||||
from ..exceptions import ConnectTimeoutError, NewConnectionError
|
||||
from ..poolmanager import PoolManager
|
||||
from ..util.url import parse_url
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class SOCKSConnection(HTTPConnection):
|
||||
"""
|
||||
A plain-text HTTP connection that connects via a SOCKS proxy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
_socks_options: _TYPE_SOCKS_OPTIONS,
|
||||
*args: typing.Any,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
self._socks_options = _socks_options
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _new_conn(self) -> socks.socksocket:
|
||||
"""
|
||||
Establish a new connection via the SOCKS proxy.
|
||||
"""
|
||||
extra_kw: dict[str, typing.Any] = {}
|
||||
if self.source_address:
|
||||
extra_kw["source_address"] = self.source_address
|
||||
|
||||
if self.socket_options:
|
||||
only_tcp_options = []
|
||||
|
||||
for opt in self.socket_options:
|
||||
if len(opt) == 3:
|
||||
only_tcp_options.append(opt)
|
||||
elif len(opt) == 4:
|
||||
protocol: str = opt[3].lower()
|
||||
if protocol == "udp":
|
||||
continue
|
||||
only_tcp_options.append(opt[:3])
|
||||
|
||||
extra_kw["socket_options"] = only_tcp_options
|
||||
|
||||
try:
|
||||
conn = socks.create_connection(
|
||||
(self.host, self.port),
|
||||
proxy_type=self._socks_options["socks_version"], # type: ignore[arg-type]
|
||||
proxy_addr=self._socks_options["proxy_host"],
|
||||
proxy_port=self._socks_options["proxy_port"], # type: ignore[arg-type]
|
||||
proxy_username=self._socks_options["username"],
|
||||
proxy_password=self._socks_options["password"],
|
||||
proxy_rdns=self._socks_options["rdns"],
|
||||
timeout=self.timeout, # type: ignore[arg-type]
|
||||
**extra_kw,
|
||||
)
|
||||
|
||||
except SocketTimeout as e:
|
||||
raise ConnectTimeoutError(
|
||||
self,
|
||||
f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
|
||||
) from e
|
||||
|
||||
except socks.ProxyError as e:
|
||||
# This is fragile as hell, but it seems to be the only way to raise
|
||||
# useful errors here.
|
||||
if e.socket_err:
|
||||
error = e.socket_err
|
||||
if isinstance(error, SocketTimeout):
|
||||
raise ConnectTimeoutError(
|
||||
self,
|
||||
f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
|
||||
) from e
|
||||
else:
|
||||
# Adding `from e` messes with coverage somehow, so it's omitted.
|
||||
# See #2386.
|
||||
raise NewConnectionError(
|
||||
self, f"Failed to establish a new connection: {error}"
|
||||
)
|
||||
else:
|
||||
raise NewConnectionError(
|
||||
self, f"Failed to establish a new connection: {e}"
|
||||
) from e
|
||||
|
||||
except OSError as e: # Defensive: PySocks should catch all these.
|
||||
raise NewConnectionError(
|
||||
self, f"Failed to establish a new connection: {e}"
|
||||
) from e
|
||||
|
||||
return conn
|
||||
|
||||
|
||||
# We don't need to duplicate the Verified/Unverified distinction from
|
||||
# urllib3/connection.py here because the HTTPSConnection will already have been
|
||||
# correctly set to either the Verified or Unverified form by that module. This
|
||||
# means the SOCKSHTTPSConnection will automatically be the correct type.
|
||||
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
|
||||
pass
|
||||
|
||||
|
||||
class SOCKSHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = SOCKSConnection
|
||||
|
||||
|
||||
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
ConnectionCls = SOCKSHTTPSConnection
|
||||
|
||||
|
||||
class SOCKSProxyManager(PoolManager):
|
||||
"""
|
||||
A version of the urllib3 ProxyManager that routes connections via the
|
||||
defined SOCKS proxy.
|
||||
"""
|
||||
|
||||
pool_classes_by_scheme = {
|
||||
"http": SOCKSHTTPConnectionPool,
|
||||
"https": SOCKSHTTPSConnectionPool,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
proxy_url: str,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
num_pools: int = 10,
|
||||
headers: typing.Mapping[str, str] | None = None,
|
||||
**connection_pool_kw: typing.Any,
|
||||
):
|
||||
parsed = parse_url(proxy_url)
|
||||
|
||||
if username is None and password is None and parsed.auth is not None:
|
||||
split = parsed.auth.split(":")
|
||||
if len(split) == 2:
|
||||
username, password = split
|
||||
if parsed.scheme == "socks5":
|
||||
socks_version = socks.PROXY_TYPE_SOCKS5
|
||||
rdns = False
|
||||
elif parsed.scheme == "socks5h":
|
||||
socks_version = socks.PROXY_TYPE_SOCKS5
|
||||
rdns = True
|
||||
elif parsed.scheme == "socks4":
|
||||
socks_version = socks.PROXY_TYPE_SOCKS4
|
||||
rdns = False
|
||||
elif parsed.scheme == "socks4a":
|
||||
socks_version = socks.PROXY_TYPE_SOCKS4
|
||||
rdns = True
|
||||
else:
|
||||
raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
|
||||
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
socks_options = {
|
||||
"socks_version": socks_version,
|
||||
"proxy_host": parsed.host,
|
||||
"proxy_port": parsed.port,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"rdns": rdns,
|
||||
}
|
||||
connection_pool_kw["_socks_options"] = socks_options
|
||||
|
||||
if "disabled_svn" not in connection_pool_kw:
|
||||
connection_pool_kw["disabled_svn"] = set()
|
||||
|
||||
connection_pool_kw["disabled_svn"].add(HttpVersion.h3)
|
||||
|
||||
super().__init__(num_pools, headers, **connection_pool_kw)
|
||||
|
||||
self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
|
||||
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
This is hazmat. It can blow up anytime.
|
||||
Use it with precautions!
|
||||
|
||||
Reasoning behind this:
|
||||
|
||||
1) python-socks requires another dependency, namely asyncio-timeout, that is one too much for us.
|
||||
2) it does not support our AsyncSocket wrapper (it has his own internally)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from python_socks import _abc as abc
|
||||
|
||||
# look the other way if unpleasant. No choice for now.
|
||||
# will start discussions once we have a solid traffic.
|
||||
from python_socks._connectors.abc import AsyncConnector
|
||||
from python_socks._connectors.socks4_async import Socks4AsyncConnector
|
||||
from python_socks._connectors.socks5_async import Socks5AsyncConnector
|
||||
from python_socks._errors import ProxyError, ProxyTimeoutError
|
||||
from python_socks._helpers import parse_proxy_url
|
||||
from python_socks._protocols.errors import ReplyError
|
||||
from python_socks._types import ProxyType
|
||||
|
||||
from .ssa import AsyncSocket
|
||||
from .ssa._timeout import timeout as timeout_
|
||||
|
||||
|
||||
class Resolver(abc.AsyncResolver):
|
||||
def __init__(self, loop: asyncio.AbstractEventLoop):
|
||||
self._loop = loop
|
||||
|
||||
async def resolve(
|
||||
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_UNSPEC
|
||||
) -> tuple[socket.AddressFamily, str]:
|
||||
infos = await self._loop.getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=socket.SOCK_STREAM,
|
||||
)
|
||||
|
||||
if not infos: # Defensive:
|
||||
raise OSError(f"Can`t resolve address {host}:{port} [{family}]")
|
||||
|
||||
infos = sorted(infos, key=lambda info: info[0])
|
||||
|
||||
family, _, _, _, address = infos[0]
|
||||
return family, address[0]
|
||||
|
||||
|
||||
def create_connector(
|
||||
proxy_type: ProxyType,
|
||||
username: str | None,
|
||||
password: str | None,
|
||||
rdns: bool,
|
||||
resolver: abc.AsyncResolver,
|
||||
) -> AsyncConnector:
|
||||
if proxy_type == ProxyType.SOCKS4:
|
||||
return Socks4AsyncConnector(
|
||||
user_id=username,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
if proxy_type == ProxyType.SOCKS5:
|
||||
return Socks5AsyncConnector(
|
||||
username=username,
|
||||
password=password,
|
||||
rdns=rdns,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
raise ValueError(f"Invalid proxy type: {proxy_type}")
|
||||
|
||||
|
||||
class AsyncioProxy:
|
||||
def __init__(
|
||||
self,
|
||||
proxy_type: ProxyType,
|
||||
host: str,
|
||||
port: int,
|
||||
username: str | None = None,
|
||||
password: str | None = None,
|
||||
rdns: bool = False,
|
||||
):
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
self._proxy_type = proxy_type
|
||||
self._proxy_host = host
|
||||
self._proxy_port = port
|
||||
self._password = password
|
||||
self._username = username
|
||||
self._rdns = rdns
|
||||
|
||||
self._resolver = Resolver(loop=self._loop)
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
dest_host: str,
|
||||
dest_port: int,
|
||||
timeout: float | None = None,
|
||||
_socket: AsyncSocket | None = None,
|
||||
) -> AsyncSocket:
|
||||
if timeout is None:
|
||||
timeout = 60
|
||||
|
||||
try:
|
||||
async with timeout_(timeout):
|
||||
# our dependency started to deprecate passing "_socket"
|
||||
# which is ... vital for our integration. We'll start by silencing the warning.
|
||||
# then we'll think on how to proceed.
|
||||
# A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
|
||||
# B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
return await self._connect(
|
||||
dest_host=dest_host,
|
||||
dest_port=dest_port,
|
||||
_socket=_socket, # type: ignore[arg-type]
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ProxyTimeoutError(f"Proxy connection timed out: {timeout}") from e
|
||||
|
||||
async def _connect(
|
||||
self, dest_host: str, dest_port: int, _socket: AsyncSocket
|
||||
) -> AsyncSocket:
|
||||
try:
|
||||
connector = create_connector(
|
||||
proxy_type=self._proxy_type,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
rdns=self._rdns,
|
||||
resolver=self._resolver,
|
||||
)
|
||||
await connector.connect(
|
||||
stream=_socket, # type: ignore[arg-type]
|
||||
host=dest_host,
|
||||
port=dest_port,
|
||||
)
|
||||
|
||||
return _socket
|
||||
except asyncio.CancelledError: # Defensive:
|
||||
_socket.close()
|
||||
raise
|
||||
except ReplyError as e:
|
||||
_socket.close()
|
||||
raise ProxyError(e, error_code=e.error_code) # type: ignore[no-untyped-call]
|
||||
except Exception: # Defensive:
|
||||
_socket.close()
|
||||
raise
|
||||
|
||||
@property
|
||||
def proxy_host(self) -> str:
|
||||
return self._proxy_host
|
||||
|
||||
@property
|
||||
def proxy_port(self) -> int:
|
||||
return self._proxy_port
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args: typing.Any, **kwargs: typing.Any) -> AsyncioProxy:
|
||||
return cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs: typing.Any) -> AsyncioProxy:
|
||||
url_args = parse_proxy_url(url)
|
||||
return cls(*url_args, **kwargs)
|
||||
@@ -0,0 +1,27 @@
|
||||
# Dummy file to match upstream modules
|
||||
# without actually serving them.
|
||||
# urllib3-future diverged from urllib3.
|
||||
# only the top-level (public API) are guaranteed to be compatible.
|
||||
# in-fact urllib3-future propose a better way to migrate/transition toward
|
||||
# newer protocols.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
def inject_into_urllib3() -> None:
    """Compatibility stub: warns that the WASM platform is unsupported.

    Kept so code that imports the upstream-urllib3 contrib module keeps
    importing; it never patches anything.
    """
    unsupported_notice = (
        "urllib3-future does not support WASM / Emscripten platform. "
        "Please reinstall legacy urllib3 in the meantime. "
        "Run `pip uninstall -y urllib3 urllib3-future` then "
        "`pip install urllib3-future`, finally `pip install urllib3`. "
        "Sorry for the inconvenience."
    )
    warnings.warn(unsupported_notice, DeprecationWarning)
|
||||
|
||||
|
||||
def extract_from_urllib3() -> None:
    """Compatibility no-op mirroring the upstream contrib API."""
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._configuration import QuicTLSConfig
|
||||
from .protocols import (
|
||||
HTTP1Protocol,
|
||||
HTTP2Protocol,
|
||||
HTTP3Protocol,
|
||||
HTTPOverQUICProtocol,
|
||||
HTTPOverTCPProtocol,
|
||||
HTTPProtocol,
|
||||
HTTPProtocolFactory,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"QuicTLSConfig",
|
||||
"HTTP1Protocol",
|
||||
"HTTP2Protocol",
|
||||
"HTTP3Protocol",
|
||||
"HTTPOverQUICProtocol",
|
||||
"HTTPOverTCPProtocol",
|
||||
"HTTPProtocol",
|
||||
"HTTPProtocolFactory",
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
@dataclasses.dataclass
class QuicTLSConfig:
    """
    Client TLS configuration.
    """

    #: Allows to proceed for server without valid TLS certificates.
    insecure: bool = False

    #: File with CA certificates to trust for server verification
    cafile: str | None = None

    #: Directory with CA certificates to trust for server verification
    capath: str | None = None

    #: Blob with CA certificates to trust for server verification
    cadata: bytes | None = None

    #: If provided, will trigger an additional load_cert_chain() upon the QUIC Configuration
    certfile: str | bytes | None = None

    #: Private key matching ``certfile`` (path or in-memory blob).
    keyfile: str | bytes | None = None

    #: Password used to decrypt ``keyfile`` when it is encrypted.
    keypassword: str | bytes | None = None

    #: The QUIC session ticket which should be used for session resumption
    session_ticket: Any | None = None

    #: NOTE(review): presumably the expected server certificate fingerprint
    #: for pinning-style verification — confirm against the consumer.
    cert_fingerprint: str | None = None
    #: NOTE(review): presumably falls back to the certificate CommonName when
    #: matching the hostname — confirm.
    cert_use_common_name: bool = False

    #: When False, skip matching the certificate against the target hostname.
    verify_hostname: bool = True
    #: Explicit hostname to match instead of the connected host, when set.
    assert_hostname: str | None = None

    #: NOTE(review): TLS 1.3 cipher suite descriptors passed down to the QUIC
    #: stack — exact mapping shape depends on the consumer; confirm.
    ciphers: list[Mapping[str, Any]] | None = None

    #: Idle timeout (seconds) before the QUIC connection is considered dead.
    idle_timeout: float = 300.0
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from .events import Event
|
||||
|
||||
|
||||
class StreamMatrix:
    """Efficient way to store events for concurrent streams.

    Events are kept in one deque per stream id, plus a ``None`` bucket for
    connection-level (non-stream) events. Each event receives a monotonically
    increasing ``_id`` on insertion so that global ordering can be recovered
    when draining across buckets.
    """

    __slots__ = (
        "_matrix",          # per-stream deques; key None holds connection-level events
        "_count",           # total number of stored events across all buckets
        "_event_cursor_id", # next insertion id to stamp onto an event
    )

    def __init__(self) -> None:
        self._matrix: dict[int | None, deque[Event]] = {}
        self._count: int = 0
        self._event_cursor_id: int = 0

    def __len__(self) -> int:
        """Total number of queued events, all streams included."""
        return self._count

    def __bool__(self) -> bool:
        """True when at least one event is queued."""
        return self._count > 0

    @property
    def streams(self) -> list[int]:
        """Sorted list of stream ids that currently hold events (None excluded)."""
        return sorted(i for i in self._matrix.keys() if i is not None)

    def append(self, event: Event) -> None:
        """Queue ``event`` at the tail of its stream's bucket."""
        # Events without a stream_id attribute go to the connection-level bucket.
        matrix_idx = getattr(event, "stream_id", None)

        # Stamp the insertion order so popleft() can interleave buckets correctly.
        event._id = self._event_cursor_id
        self._event_cursor_id += 1

        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()

        self._matrix[matrix_idx].append(event)

        self._count += 1

    def extend(self, events: typing.Iterable[Event]) -> None:
        """Queue many events at once, grouping per stream before insertion."""
        triaged_events: dict[int | None, list[Event]] = {}

        for event in events:
            matrix_idx = getattr(event, "stream_id", None)

            event._id = self._event_cursor_id

            self._event_cursor_id += 1
            self._count += 1

            if matrix_idx not in triaged_events:
                triaged_events[matrix_idx] = []

            triaged_events[matrix_idx].append(event)

        # One deque.extend per bucket instead of one append per event.
        for k, v in triaged_events.items():
            if k not in self._matrix:
                self._matrix[k] = deque()

            self._matrix[k].extend(v)

    def appendleft(self, event: Event) -> None:
        """Re-queue ``event`` at the head of its stream's bucket.

        NOTE(review): the event still receives a fresh (higher) ``_id`` even
        though it is placed at the front — confirm this is intended for
        reshelved events, as it affects cross-bucket ordering in popleft().
        """
        matrix_idx = getattr(event, "stream_id", None)
        event._id = self._event_cursor_id
        self._event_cursor_id += 1

        if matrix_idx not in self._matrix:
            self._matrix[matrix_idx] = deque()

        self._matrix[matrix_idx].appendleft(event)

        self._count += 1

    def popleft(self, stream_id: int | None = None) -> Event | None:
        """Pop the next event, or None when nothing matches.

        With an explicit ``stream_id``, pop only from that bucket. With
        ``stream_id=None``, pick a bucket: the first stream bucket found,
        unless the connection-level bucket holds an older event (compared
        by insertion ``_id``), in which case the global event wins.
        """
        if self._count == 0:
            return None

        have_global_event: bool = None in self._matrix and bool(self._matrix[None])
        # Whether at least one *stream* bucket exists besides the None bucket.
        any_stream_event: bool = (
            bool(self._matrix) if not have_global_event else len(self._matrix) > 1
        )

        if stream_id is None and any_stream_event:
            # Grab the first stream key, skipping the None bucket if it comes first.
            matrix_dict_iter = self._matrix.__iter__()

            stream_id = next(matrix_dict_iter)

            if stream_id is None:
                stream_id = next(matrix_dict_iter)

            # Prefer the connection-level bucket when its head event is older.
            if (
                stream_id is not None
                and have_global_event
                and stream_id in self._matrix
                and self._matrix[None][0]._id < self._matrix[stream_id][0]._id
            ):
                stream_id = None
            elif have_global_event is True and stream_id not in self._matrix:
                stream_id = None

        if stream_id not in self._matrix:
            return None

        ev = self._matrix[stream_id].popleft()

        if ev is not None:
            self._count -= 1

        # Drop emptied stream buckets; the None bucket is kept alive.
        if stream_id is not None and not self._matrix[stream_id]:
            del self._matrix[stream_id]

        return ev

    def count(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> int:
        """Number of queued events, optionally per stream and excluding types.

        NOTE(review): ``excl_event`` is ignored when ``stream_id`` is None —
        the global total is returned unfiltered; confirm callers expect this.
        """
        if stream_id is None:
            return self._count
        if stream_id not in self._matrix:
            return 0

        return len(
            self._matrix[stream_id]
            if excl_event is None
            else [e for e in self._matrix[stream_id] if not isinstance(e, excl_event)]
        )

    def has(
        self,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Whether any event is queued, optionally per stream / excluding types.

        NOTE(review): as with count(), ``excl_event`` is only honored when a
        ``stream_id`` is given.
        """
        if stream_id is None:
            return True if self._count else False
        if stream_id not in self._matrix:
            return False

        if excl_event is not None:
            return any(
                e for e in self._matrix[stream_id] if not isinstance(e, excl_event)
            )

        return True if self._matrix[stream_id] else False
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
# A single HTTP header as an already-encoded (name, value) pair.
HeaderType = Tuple[bytes, bytes]
# An ordered sequence of headers; ordering is significant in HTTP.
HeadersType = Sequence[HeaderType]

# A network endpoint expressed as (host, port).
AddressType = Tuple[str, int]
# A UDP datagram payload together with its remote address.
DatagramType = Tuple[bytes, AddressType]
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._events import (
|
||||
ConnectionTerminated,
|
||||
DataReceived,
|
||||
EarlyHeadersReceived,
|
||||
Event,
|
||||
GoawayReceived,
|
||||
HandshakeCompleted,
|
||||
HeadersReceived,
|
||||
StreamEvent,
|
||||
StreamReset,
|
||||
StreamResetReceived,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"Event",
|
||||
"ConnectionTerminated",
|
||||
"GoawayReceived",
|
||||
"StreamEvent",
|
||||
"StreamReset",
|
||||
"StreamResetReceived",
|
||||
"HeadersReceived",
|
||||
"DataReceived",
|
||||
"HandshakeCompleted",
|
||||
"EarlyHeadersReceived",
|
||||
)
|
||||
@@ -0,0 +1,202 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .._typing import HeadersType
|
||||
|
||||
|
||||
class Event:
    """
    Base class for HTTP events.

    This is an abstract base class that should not be initialized.
    """

    # Monotonic insertion id stamped by the event bookkeeping (e.g.
    # StreamMatrix.append) to recover global ordering across streams.
    # Deliberately not a dataclass field: it is assigned after construction.
    _id: int
|
||||
|
||||
|
||||
#
|
||||
# Connection events
|
||||
#
|
||||
|
||||
|
||||
@dataclass
class ConnectionTerminated(Event):
    """
    The connection was shut down.

    Extends :class:`.Event`.
    """

    #: Reason for closing the connection.
    error_code: int = 0

    #: Optional message with more information
    message: str | None = field(default=None, compare=False)

    def __repr__(self) -> str:  # Defensive: debug purposes only
        return (
            f"{type(self).__name__}"
            f"(error_code={self.error_code!r}, message={self.message!r})"
        )
|
||||
|
||||
|
||||
@dataclass
class GoawayReceived(Event):
    """
    GOAWAY frame was received.

    Extends :class:`.Event`.
    """

    #: Highest stream ID that could be processed.
    last_stream_id: int

    #: Reason for closing the connection.
    error_code: int = 0

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(last_stream_id={self.last_stream_id!r}, "
            f"error_code={self.error_code!r})"
        )
|
||||
|
||||
|
||||
#
|
||||
# Stream events
|
||||
#
|
||||
|
||||
|
||||
@dataclass
class StreamEvent(Event):
    """
    Event on one HTTP stream.

    This is an abstract base class that should not be used directly.

    Extends :class:`.Event`.
    """

    #: Stream ID
    stream_id: int
|
||||
|
||||
|
||||
@dataclass
class StreamReset(StreamEvent):
    """
    One stream of an HTTP connection was reset.

    When a stream is reset, it must no longer be used, but the parent
    connection and other streams are unaffected.

    This is an abstract base class that should not be used directly.
    More specific subclasses (StreamResetSent or StreamResetReceived)
    should be emitted.

    Extends :class:`.StreamEvent`.
    """

    #: Reason for closing the stream.
    error_code: int = 0
    #: NOTE(review): a reset appears to also terminate the stream for data,
    #: hence the True default — confirm consumers rely on this.
    end_stream: bool = True

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return f"{cls}(stream_id={self.stream_id!r}, error_code={self.error_code!r})"
|
||||
|
||||
|
||||
@dataclass
class StreamResetReceived(StreamReset):
    """
    One stream of an HTTP connection was reset by the peer.

    This probably means that we did something that the peer does not like.

    Extends :class:`.StreamReset`.
    """
|
||||
|
||||
|
||||
@dataclass
class HandshakeCompleted(Event):
    """The transport-level handshake finished. Extends :class:`.Event`."""

    #: The ALPN protocol negotiated during the handshake, or None when
    #: no protocol was negotiated.
    alpn_protocol: str | None

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return f"{cls}(alpn={self.alpn_protocol})"
|
||||
|
||||
|
||||
@dataclass
class HeadersReceived(StreamEvent):
    """
    A frame with HTTP headers was received.

    Extends :class:`.StreamEvent`.
    """

    #: The received HTTP headers
    headers: HeadersType

    #: Signals that data will not be sent by the peer over the stream.
    end_stream: bool = False

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(headers)={len(self.headers)}, end_stream={self.end_stream!r})"
        )
|
||||
|
||||
|
||||
@dataclass
class DataReceived(StreamEvent):
    """
    A frame with HTTP data was received.

    Extends :class:`.StreamEvent`.
    """

    #: The received data.
    data: bytes

    #: Signals that no more data will be sent by the peer over the stream.
    end_stream: bool = False

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(data)={len(self.data)}, end_stream={self.end_stream!r})"
        )
|
||||
|
||||
|
||||
@dataclass
class EarlyHeadersReceived(StreamEvent):
    """Headers received before the final response headers.

    NOTE(review): presumably carries interim (1xx) response headers —
    confirm against the emitting protocol implementations.

    Extends :class:`.StreamEvent`.
    """

    #: The received HTTP headers
    headers: HeadersType

    def __repr__(self) -> str:  # Defensive: debug purposes only
        cls = type(self).__name__
        return (
            f"{cls}(stream_id={self.stream_id!r}, "
            f"len(headers)={len(self.headers)}, end_stream=False)"
        )

    # Early headers never terminate the stream, so end_stream is a constant
    # property rather than a field (unlike HeadersReceived).
    @property
    def end_stream(self) -> typing.Literal[False]:
        return False
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._factories import HTTPProtocolFactory
|
||||
from ._protocols import (
|
||||
HTTP1Protocol,
|
||||
HTTP2Protocol,
|
||||
HTTP3Protocol,
|
||||
HTTPOverQUICProtocol,
|
||||
HTTPOverTCPProtocol,
|
||||
HTTPProtocol,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"HTTP1Protocol",
|
||||
"HTTP2Protocol",
|
||||
"HTTP3Protocol",
|
||||
"HTTPOverQUICProtocol",
|
||||
"HTTPOverTCPProtocol",
|
||||
"HTTPProtocol",
|
||||
"HTTPProtocolFactory",
|
||||
)
|
||||
@@ -0,0 +1,90 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
HTTP factories create HTTP protocols based on a defined set of arguments.
|
||||
|
||||
We define the :class:`HTTPProtocol` interface to allow interchange
|
||||
HTTP versions and protocol implementations. But constructors of
|
||||
the class is not part of the interface. Every implementation
|
||||
can use different options to init instances.
|
||||
|
||||
Factories unify access to the creation of the protocol instances,
|
||||
so that clients and servers can swap protocol implementations,
|
||||
delegating the initialization to factories.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from abc import ABCMeta
|
||||
from typing import Any
|
||||
|
||||
from ._protocols import HTTPOverQUICProtocol, HTTPOverTCPProtocol, HTTPProtocol
|
||||
|
||||
|
||||
class HTTPProtocolFactory(metaclass=ABCMeta):
    """Entry point that instantiates a concrete protocol for a requested HTTP version."""

    @staticmethod
    def new(
        type_protocol: type[HTTPProtocol],
        implementation: str | None = None,
        **kwargs: Any,
    ) -> HTTPOverQUICProtocol | HTTPOverTCPProtocol:
        """Create a new state-machine that targets the given protocol type.

        :param type_protocol: one of the HTTP1/HTTP2/HTTP3 protocol interfaces
            (never the ambiguous ``HTTPProtocol`` base itself).
        :param implementation: optional backend name, resolved to a
            ``._<implementation>`` submodule.
        :param kwargs: forwarded verbatim to the selected implementation class.
        :raises NotImplementedError: when no matching module or class is found.
        """
        # NOTE(review): assert is stripped under `python -O`; a raise would be
        # sturdier, but changing it would alter the exception type callers see.
        assert type_protocol != HTTPProtocol, (
            "HTTPProtocol is ambiguous and cannot be requested in the factory."
        )

        # Root package name, used both to strip from the class repr and as the
        # anchor package for the relative import below.
        package_name: str = __name__.split(".")[0]

        # Extract the HTTP version digit(s) from the class repr, e.g.
        # "HTTP3Protocol" -> "3" -> module ".protocols.http3".
        version_target: str = "".join(
            c for c in str(type_protocol).replace(package_name, "") if c.isdigit()
        )
        module_expr: str = f".protocols.http{version_target}"

        if implementation:
            module_expr += f"._{implementation.lower()}"

        try:
            http_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.hface"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. Tried to import '{module_expr}'."
            ) from e

        # Collect every concrete protocol class exposed by the module.
        implementations: list[
            tuple[str, type[HTTPOverQUICProtocol | HTTPOverTCPProtocol]]
        ] = inspect.getmembers(
            http_module,
            lambda e: isinstance(e, type)
            and issubclass(e, (HTTPOverQUICProtocol, HTTPOverTCPProtocol)),
        )

        if not implementations:
            raise NotImplementedError(
                f"{type_protocol} cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit either from HTTPOverQUICProtocol or HTTPOverTCPProtocol."
            )

        # getmembers() returns name-sorted members; the last entry wins.
        implementation_target: type[HTTPOverQUICProtocol | HTTPOverTCPProtocol] = (
            implementations.pop()[1]
        )

        return implementation_target(**kwargs)
|
||||
@@ -0,0 +1,358 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Sequence
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._typing import HeadersType
|
||||
from ..events import Event
|
||||
|
||||
|
||||
class BaseProtocol(metaclass=ABCMeta):
    """Sans-IO common methods whenever it is TCP, UDP or QUIC."""

    @abstractmethod
    def bytes_received(self, data: bytes) -> None:
        """
        Called when some data is received.
        """
        raise NotImplementedError

    # Sending direction

    @abstractmethod
    def bytes_to_send(self) -> bytes:
        """
        Returns data for sending out of the internal data buffer.
        """
        raise NotImplementedError

    @abstractmethod
    def connection_lost(self) -> None:
        """
        Called when the connection is lost or closed.
        """
        raise NotImplementedError

    # Optional hook (not @abstractmethod): subclasses that have no flow
    # control concept may leave it raising NotImplementedError.
    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        """
        Verify if the client should listen for incoming network data for
        flow control update purposes.
        """
        raise NotImplementedError

    # Optional hook, same rationale as above.
    def max_frame_size(self) -> int:
        """
        Determine if the remote set a limited size for each data frame.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class OverTCPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top of TCP.
    """

    @abstractmethod
    def eof_received(self) -> None:
        """
        Called when the other end signals it won’t send any more data.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class OverUDPProtocol(BaseProtocol, metaclass=ABCMeta):
    """
    Interface for sans-IO protocols on top of UDP.
    """
|
||||
|
||||
|
||||
class OverQUICProtocol(OverUDPProtocol):
    """Sans-IO interface for protocols running over QUIC."""

    @property
    @abstractmethod
    def connection_ids(self) -> Sequence[bytes]:
        """
        QUIC connection IDs

        This property can be used to assign UDP packets to QUIC connections.

        :return: a sequence of connection IDs
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def session_ticket(self) -> Any | None:
        # Session ticket usable for 0-RTT / session resumption (see
        # QuicTLSConfig.session_ticket); None when none was issued.
        raise NotImplementedError

    # Overloads mirror the ssl.SSLSocket.getpeercert() calling convention
    # (binary DER bytes vs. decoded dict) — NOTE(review): confirm the dict
    # layout matches ssl's in the concrete implementations.
    @typing.overload
    def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...

    @typing.overload
    def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...

    @abstractmethod
    def getpeercert(self, *, binary_form: bool = False) -> bytes | dict[str, Any]:
        raise NotImplementedError

    @typing.overload
    def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...

    @typing.overload
    def getissuercert(
        self, *, binary_form: Literal[False] = ...
    ) -> dict[str, Any] | None: ...

    @abstractmethod
    def getissuercert(
        self, *, binary_form: bool = False
    ) -> bytes | dict[str, Any] | None:
        # May return None — presumably when the issuer certificate is
        # unavailable; confirm in implementations.
        raise NotImplementedError

    @abstractmethod
    def cipher(self) -> str | None:
        # Name of the negotiated TLS cipher, None before the handshake.
        raise NotImplementedError
|
||||
|
||||
|
||||
class HTTPProtocol(metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP connection
    """

    # Short backend identifier of the concrete implementation (set by subclasses).
    implementation: str

    @staticmethod
    @abstractmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Return exception types that should be handled in your application."""
        raise NotImplementedError

    @property
    @abstractmethod
    def multiplexed(self) -> bool:
        """
        Whether this connection supports multiple parallel streams.

        Returns ``True`` for HTTP/2 and HTTP/3 connections.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def max_stream_count(self) -> int:
        """Determine how many concurrent streams the connection can handle."""
        raise NotImplementedError

    @abstractmethod
    def is_idle(self) -> bool:
        """
        Return True if this connection is BOTH available and not doing anything.
        """
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """
        Return whether this connection is capable of opening new streams.
        """
        raise NotImplementedError

    @abstractmethod
    def has_expired(self) -> bool:
        """
        Return whether this connection is closed or should be closed.
        """
        raise NotImplementedError

    @abstractmethod
    def get_available_stream_id(self) -> int:
        """
        Return an ID that can be used to create a new stream.

        Use the returned ID with :meth:`.submit_headers` to create the stream.
        This method may or may not return one value until that method is called.

        :return: stream ID
        """
        raise NotImplementedError

    @abstractmethod
    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP headers.

        If this is a client connection, this method starts an HTTP request.
        If this is a server connection, it starts an HTTP response.

        :param stream_id: stream ID
        :param headers: HTTP headers
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError

    @abstractmethod
    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """
        Submit a frame with HTTP data.

        :param stream_id: stream ID
        :param data: payload
        :param end_stream: whether to close the stream for sending
        """
        raise NotImplementedError

    @abstractmethod
    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        """
        Immediately terminate a stream.

        Stream reset is used to request cancellation of a stream
        or to indicate that an error condition has occurred.

        Use :attr:`.error_codes` to obtain error codes for common problems.

        :param stream_id: stream ID
        :param error_code: indicates why the stream is being terminated
        """
        raise NotImplementedError

    @abstractmethod
    def submit_close(self, error_code: int = 0) -> None:
        """
        Submit a graceful close of the connection.

        Use :attr:`.error_codes` to obtain error codes for common problems.

        :param error_code: indicates why the connection is being closed
        """
        raise NotImplementedError

    @abstractmethod
    def next_event(self, stream_id: int | None = None) -> Event | None:
        """
        Consume the next HTTP event.

        :return: an event instance, or None when no event is pending
        """
        raise NotImplementedError

    # Concrete convenience built on next_event(); not abstract on purpose.
    def events(self, stream_id: int | None = None) -> typing.Iterator[Event]:
        """
        Consume available HTTP events.

        :return: an iterator that unpacks "next_event" until exhausted.
        """
        while True:
            ev = self.next_event(stream_id=stream_id)

            if ev is None:
                break

            yield ev

    @abstractmethod
    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        """Verify if there is a queued event waiting to be consumed."""
        raise NotImplementedError

    @abstractmethod
    def reshelve(self, *events: Event) -> None:
        """Put back events into the deque."""
        raise NotImplementedError

    @abstractmethod
    def ping(self) -> None:
        """Send a PING frame to the remote peer. Thus keeping the connection alive."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class HTTPOverTCPProtocol(HTTPProtocol, OverTCPProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a TCP connection

    An interface for HTTP/1 and HTTP/2 protocols.
    Extends :class:`.HTTPProtocol`.
    """
|
||||
|
||||
|
||||
class HTTPOverQUICProtocol(HTTPProtocol, OverQUICProtocol, metaclass=ABCMeta):
    """
    :class:`HTTPProtocol` over a QUIC connection

    Abstract base class for HTTP/3 protocols.
    Extends :class:`.HTTPProtocol`.
    """
|
||||
|
||||
|
||||
class HTTP1Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/1 connection

    An interface for HTTP/1 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """

    # HTTP/1 handles a single in-flight exchange per connection.
    @property
    def multiplexed(self) -> bool:
        return False

    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        # Returns the NotImplemented sentinel (rather than raising) —
        # presumably signalling "no flow control concept in HTTP/1" to
        # callers; confirm against call sites.
        return NotImplemented  # type: ignore[no-any-return]
|
||||
|
||||
|
||||
class HTTP2Protocol(HTTPOverTCPProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/2 connection

    An abstract base class for HTTP/2 implementations.
    Extends :class:`.HTTPOverTCPProtocol`.
    """

    # HTTP/2 supports concurrent streams over one TCP connection.
    @property
    def multiplexed(self) -> bool:
        return True
|
||||
|
||||
|
||||
class HTTP3Protocol(HTTPOverQUICProtocol, metaclass=ABCMeta):
    """
    Sans-IO representation of an HTTP/3 connection

    An abstract base class for HTTP/3 implementations.
    Extends :class:`.HTTPOverQUICProtocol`
    """

    # HTTP/3 supports concurrent streams over one QUIC connection.
    @property
    def multiplexed(self) -> bool:
        return True
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._h11 import HTTP1ProtocolHyperImpl
|
||||
|
||||
__all__ = ("HTTP1ProtocolHyperImpl",)
|
||||
@@ -0,0 +1,347 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
|
||||
import h11
|
||||
from h11._state import _SWITCH_UPGRADE, ConnectionState
|
||||
|
||||
from ..._stream_matrix import StreamMatrix
|
||||
from ..._typing import HeadersType
|
||||
from ...events import (
|
||||
ConnectionTerminated,
|
||||
DataReceived,
|
||||
EarlyHeadersReceived,
|
||||
Event,
|
||||
HeadersReceived,
|
||||
)
|
||||
from .._protocols import HTTP1Protocol
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
def capitalize_header_name(name: bytes) -> bytes:
    """
    Take a header name and capitalize it.

    Each ``-``-separated segment is capitalized independently; other
    separators (such as ``_``) are left untouched. The result is cached
    because the same small set of header names recurs on every request.

    >>> capitalize_header_name(b"x-hEllo-wORLD")
    b'X-Hello-World'
    >>> capitalize_header_name(b"server")
    b'Server'
    >>> capitalize_header_name(b"contEnt-TYPE")
    b'Content-Type'
    >>> capitalize_header_name(b"content_type")
    b'Content_type'
    """
    return b"-".join(el.capitalize() for el in name.split(b"-"))
|
||||
|
||||
|
||||
def headers_to_request(headers: HeadersType) -> h11.Event:
    """Build an :class:`h11.Request` from HTTP/2-style request headers.

    Pseudo (colon-prefixed) headers are folded into the request line;
    remaining headers are forwarded with capitalized names.

    :raises ValueError: on an unknown pseudo-header, a missing
        ``:authority``, or a ``Host`` header conflicting with ``:authority``.
    """
    method = authority = path = host = None
    plain_headers = []

    for field, value in headers:
        if not field.startswith(b":"):
            if host is None and field == b"host":
                host = value
            # We found that many projects... actually expect the header name to be sent capitalized... hardcoded
            # within their tests. Bad news, we have to keep doing this nonsense (namely capitalize_header_name)
            plain_headers.append((capitalize_header_name(field), value))
            continue

        if field == b":method":
            method = value
        elif field == b":authority":
            authority = value
        elif field == b":path":
            path = value
        elif field != b":scheme":  # :scheme is accepted but unused here
            raise ValueError("Unexpected request header: " + field.decode())

    if authority is None:
        raise ValueError("Missing request header: :authority")

    # CONNECT requests are a special case: their target is the authority itself.
    if method == b"CONNECT" and path is None:
        target = authority
    else:
        target = path  # type: ignore[assignment]

    if host is None:
        plain_headers.insert(0, (b"Host", authority))
    elif host != authority:
        raise ValueError("Host header does not match :authority.")

    return h11.Request(
        method=method,  # type: ignore[arg-type]
        headers=plain_headers,
        target=target,
    )
|
||||
|
||||
|
||||
def headers_from_response(
    response: h11.InformationalResponse | h11.Response,
) -> HeadersType:
    """
    Convert an HTTP/1.0 or HTTP/1.1 response into HTTP/2-like headers.

    The status line becomes a ``:status`` pseudo-header, followed by the
    raw response headers unchanged.
    """
    status_entry = (b":status", str(response.status_code).encode("ascii"))
    return [status_entry] + response.headers.raw_items()
|
||||
|
||||
|
||||
class RelaxConnectionState(ConnectionState):
    """h11 ``ConnectionState`` variant tolerating unsolicited protocol switches.

    Vanilla h11 rejects a server-initiated Upgrade switch when the client did
    not propose one; the legacy ``http.client`` accepted it. We relax the
    state machine to match, while emitting a :class:`DeprecationWarning`.
    """

    def process_event(  # type: ignore[no-untyped-def]
        self,
        role,
        event_type,
        server_switch_event=None,
    ) -> None:
        if server_switch_event is not None:
            if server_switch_event not in self.pending_switch_proposals:
                if server_switch_event is _SWITCH_UPGRADE:
                    warnings.warn(
                        f"Received server {server_switch_event} event without a pending proposal. "
                        "This will raise an exception in a future version. It is temporarily relaxed to match the "
                        "legacy http.client standard library.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
                    # Pretend the client proposed the upgrade so the parent
                    # state machine accepts the server's switch event.
                    self.pending_switch_proposals.add(_SWITCH_UPGRADE)

        return super().process_event(role, event_type, server_switch_event)
|
||||
|
||||
|
||||
class HTTP1ProtocolHyperImpl(HTTP1Protocol):
    """Sans-IO HTTP/1.1 client protocol backed by the h11 state machine.

    h11 events are mapped onto urllib3-future :class:`Event` objects.
    HTTP/2-like stream semantics are emulated with a single pseudo stream
    id that is incremented at each completed request/response cycle
    (keep-alive). Protocol upgrades (``SWITCHED_PROTOCOL``) bypass h11 and
    move raw bytes directly.
    """

    # Name of the underlying HTTP/1 implementation.
    implementation: str = "h11"

    def __init__(self) -> None:
        self._connection: h11.Connection = h11.Connection(h11.CLIENT)
        # Swap in the relaxed state machine to tolerate unsolicited upgrades.
        self._connection._cstate = RelaxConnectionState()

        # Outbound bytes accumulated until bytes_to_send() drains them.
        self._data_buffer: list[bytes] = []
        # Inbound events, addressable per stream id.
        self._events: StreamMatrix = StreamMatrix()
        self._terminated: bool = False
        # True once we forwarded trailing data after a protocol switch.
        self._switched: bool = False

        # Pseudo stream id; bumped per request/response cycle (keep-alive).
        self._current_stream_id: int = 1

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types the backend may raise, for callers to catch."""
        return h11.LocalProtocolError, h11.ProtocolError, h11.RemoteProtocolError

    def is_available(self) -> bool:
        # A new request may only start when both peers are back to IDLE.
        return self._connection.our_state == self._connection.their_state == h11.IDLE

    @property
    def max_stream_count(self) -> int:
        # HTTP/1 carries exactly one in-flight exchange.
        return 1

    def is_idle(self) -> bool:
        return self._connection.their_state in {
            h11.IDLE,
            h11.MUST_CLOSE,
        }

    def has_expired(self) -> bool:
        return self._terminated

    def get_available_stream_id(self) -> int:
        if not self.is_available():
            raise RuntimeError(
                "Cannot generate a new stream ID because the connection is not idle. "
                "HTTP/1.1 is not multiplexed and we do not support HTTP pipelining."
            )
        return self._current_stream_id

    def submit_close(self, error_code: int = 0) -> None:
        pass  # no-op

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """Send the request line and headers; optionally end the message."""
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")

        self._h11_submit(headers_to_request(headers))

        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """Send body bytes; after a protocol switch, bytes go out raw."""
        if stream_id != self._current_stream_id:
            raise ValueError("Invalid stream ID.")
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # h11 no longer frames data once the protocol switched.
            self._data_buffer.append(data)
            if end_stream:
                self._events.append(self._connection_terminated())
            return
        self._h11_submit(h11.Data(data))
        if end_stream:
            self._h11_submit(h11.EndOfMessage())

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        # HTTP/1 cannot submit a stream (it does not have real streams).
        # But if there are no other streams, we can close the connection instead.
        self.connection_lost()

    def connection_lost(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # This method is called when the connection is closed without an EOF.
        # But not all connections support EOF, so being here does not
        # necessarily mean that something when wrong.
        #
        # The tricky part is that HTTP/1.0 server can send responses
        # without Content-Length or Transfer-Encoding headers,
        # meaning that a response body is closed with the connection.
        # In such cases, we require a proper EOF to distinguish complete
        # messages from partial messages interrupted by network failure.
        if not self._terminated:
            self._connection.send_failed()
            self._events.append(self._connection_terminated())

    def eof_received(self) -> None:
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            self._events.append(self._connection_terminated())
            return
        # h11 treats an empty feed as EOF.
        self._h11_data_received(b"")

    def bytes_received(self, data: bytes) -> None:
        """Feed raw inbound bytes into the state machine."""
        if not data:
            return  # h11 treats empty data as EOF.
        if self._connection.their_state == h11.SWITCHED_PROTOCOL:
            # Post-upgrade: pass bytes through untouched.
            self._events.append(DataReceived(self._current_stream_id, data))
            return
        else:
            self._h11_data_received(data)

    def bytes_to_send(self) -> bytes:
        """Drain and return the outbound buffer."""
        data = b"".join(self._data_buffer)
        self._data_buffer.clear()
        self._maybe_start_next_cycle()
        return data

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _h11_submit(self, h11_event: h11.Event) -> None:
        # Serialize an h11 event into the outbound buffer.
        chunks = self._connection.send_with_data_passthrough(h11_event)
        if chunks:
            self._data_buffer += chunks

    def _h11_data_received(self, data: bytes) -> None:
        self._connection.receive_data(data)
        self._fetch_events()

    def _fetch_events(self) -> None:
        """Pump h11 and translate its events into urllib3-future events."""
        a = self._events.append  # hot loop: bind the method once
        while not self._terminated:
            try:
                h11_event = self._connection.next_event()
            except h11.RemoteProtocolError as e:
                a(self._connection_terminated(e.error_status_hint, str(e)))
                break

            ev_type = h11_event.__class__

            if h11_event is h11.NEED_DATA or h11_event is h11.PAUSED:
                if h11.MUST_CLOSE == self._connection.their_state:
                    # Peer is done with this connection; surface termination.
                    a(self._connection_terminated())
                else:
                    break
            elif ev_type is h11.Response:
                a(
                    HeadersReceived(
                        self._current_stream_id,
                        headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.InformationalResponse:
                a(
                    EarlyHeadersReceived(
                        stream_id=self._current_stream_id,
                        headers=headers_from_response(h11_event),  # type: ignore[arg-type]
                    )
                )
            elif ev_type is h11.Data:
                # officially h11 typed data as "bytes"
                # but we... found that it store bytearray sometime.
                payload = h11_event.data  # type: ignore[union-attr]
                a(
                    DataReceived(
                        self._current_stream_id,
                        bytes(payload) if payload.__class__ is bytearray else payload,
                    )
                )
            elif ev_type is h11.EndOfMessage:
                # HTTP/2 and HTTP/3 send END_STREAM flag with HEADERS and DATA frames.
                # We emulate similar behavior for HTTP/1.
                if h11_event.headers:  # type: ignore[union-attr]
                    last_event: HeadersReceived | DataReceived = HeadersReceived(
                        self._current_stream_id,
                        h11_event.headers,  # type: ignore[union-attr]
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                else:
                    last_event = DataReceived(
                        self._current_stream_id,
                        b"",
                        self._connection.their_state != h11.MIGHT_SWITCH_PROTOCOL,  # type: ignore[attr-defined]
                    )
                a(last_event)
                self._maybe_start_next_cycle()
            elif ev_type is h11.ConnectionClosed:
                a(self._connection_terminated())

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> Event:
        # Marks the connection dead and returns the event describing why.
        self._terminated = True
        return ConnectionTerminated(error_code, message)

    def _maybe_start_next_cycle(self) -> None:
        """Advance to the next keep-alive cycle or protocol-switch state."""
        if h11.DONE == self._connection.our_state == self._connection.their_state:
            self._connection.start_next_cycle()
            self._current_stream_id += 1
        if h11.SWITCHED_PROTOCOL == self._connection.their_state and not self._switched:
            # Forward any bytes h11 buffered past the upgrade response.
            data, closed = self._connection.trailing_data
            if data:
                self._events.append(DataReceived(self._current_stream_id, data))
            self._switched = True

    def reshelve(self, *events: Event) -> None:
        # Reversed so the first given event ends up first in the queue.
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        raise NotImplementedError("http1 does not support PING")
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._h2 import HTTP2ProtocolHyperImpl
|
||||
|
||||
__all__ = ("HTTP2ProtocolHyperImpl",)
|
||||
@@ -0,0 +1,312 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from secrets import token_bytes
|
||||
from typing import Iterator
|
||||
|
||||
import jh2.config # type: ignore
|
||||
import jh2.connection # type: ignore
|
||||
import jh2.errors # type: ignore
|
||||
import jh2.events # type: ignore
|
||||
import jh2.exceptions # type: ignore
|
||||
import jh2.settings # type: ignore
|
||||
|
||||
from ..._stream_matrix import StreamMatrix
|
||||
from ..._typing import HeadersType
|
||||
from ...events import (
|
||||
ConnectionTerminated,
|
||||
DataReceived,
|
||||
EarlyHeadersReceived,
|
||||
Event,
|
||||
GoawayReceived,
|
||||
HandshakeCompleted,
|
||||
HeadersReceived,
|
||||
StreamResetReceived,
|
||||
)
|
||||
from .._protocols import HTTP2Protocol
|
||||
|
||||
|
||||
class _PatchedH2Connection(jh2.connection.H2Connection):  # type: ignore[misc]
    """
    This is a performance hotfix class. We internally, already keep
    track of the open stream count.
    """

    def __init__(
        self,
        config: jh2.config.H2Configuration | None = None,
        observable_impl: HTTP2ProtocolHyperImpl | None = None,
    ) -> None:
        super().__init__(config=config)
        # by default CONNECT is disabled
        # we need it to support natively WebSocket over HTTP/2 for example.
        self.local_settings = jh2.settings.Settings(
            client=True,
            initial_values={
                jh2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100,
                jh2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: self.DEFAULT_MAX_HEADER_LIST_SIZE,
                jh2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL: 1,
            },
        )
        # Back-reference to the protocol, which tracks the open stream count.
        self._observable_impl = observable_impl

    def _open_streams(self, *args, **kwargs) -> int:  # type: ignore[no-untyped-def]
        # Fast path: reuse the externally maintained counter instead of
        # letting jh2 recount open streams.
        if self._observable_impl is not None:
            return self._observable_impl._open_stream_count
        return super()._open_streams(*args, **kwargs)  # type: ignore[no-any-return]

    def _receive_goaway_frame(self, frame):  # type: ignore[no-untyped-def]
        """
        Receive a GOAWAY frame on the connection.
        We purposely override this method to work around a known bug of jh2.
        """
        events = self.state_machine.process_input(
            jh2.connection.ConnectionInputs.RECV_GOAWAY
        )

        err_code = jh2.errors._error_code_from_int(frame.error_code)

        # GOAWAY allows an
        # endpoint to gracefully stop accepting new streams while still
        # finishing processing of previously established streams.
        # see https://tools.ietf.org/html/rfc7540#section-6.8
        # hyper/h2 does not allow such a thing for now. let's work around this.
        if (
            err_code == 0
            and self._observable_impl is not None
            and self._observable_impl._open_stream_count > 0
        ):
            # Graceful shutdown with requests in flight: revert the state so
            # inflight streams may still complete.
            self.state_machine.state = jh2.connection.ConnectionState.CLIENT_OPEN

        # Clear the outbound data buffer: we cannot send further data now.
        self.clear_outbound_data_buffer()

        # Fire an appropriate ConnectionTerminated event.
        # NOTE(review): emitted even for graceful GOAWAY; the protocol layer
        # later maps error_code == 0 onto GoawayReceived instead.
        new_event = jh2.events.ConnectionTerminated()
        new_event.error_code = err_code
        new_event.last_stream_id = frame.last_stream_id
        new_event.additional_data = (
            frame.additional_data if frame.additional_data else None
        )
        events.append(new_event)

        return [], events
|
||||
|
||||
|
||||
# jh2 event types that carry a header block; both are mapped onto a single
# HeadersReceived event by HTTP2ProtocolHyperImpl._map_events.
HEADER_OR_TRAILER_TYPE_SET = {
    jh2.events.ResponseReceived,
    jh2.events.TrailersReceived,
}
|
||||
|
||||
|
||||
class HTTP2ProtocolHyperImpl(HTTP2Protocol):
    """Sans-IO HTTP/2 client protocol backed by the jh2 state machine.

    jh2 events are translated into urllib3-future :class:`Event` objects.
    Open streams are counted here (and exposed to the patched connection
    class) so graceful GOAWAY can be honored: no new streams are opened,
    while inflight exchanges are allowed to complete.
    """

    # Name of the underlying HTTP/2 implementation.
    implementation: str = "h2"

    def __init__(
        self,
        *,
        validate_outbound_headers: bool = False,
        validate_inbound_headers: bool = False,
        normalize_outbound_headers: bool = False,
        normalize_inbound_headers: bool = True,
    ) -> None:
        self._connection: jh2.connection.H2Connection = _PatchedH2Connection(
            jh2.config.H2Configuration(
                client_side=True,
                validate_outbound_headers=validate_outbound_headers,
                normalize_outbound_headers=normalize_outbound_headers,
                validate_inbound_headers=validate_inbound_headers,
                normalize_inbound_headers=normalize_inbound_headers,
            ),
            observable_impl=self,
        )
        # Streams currently open; kept in sync inside _map_events.
        self._open_stream_count: int = 0
        self._connection.initiate_connection()
        # Enlarge the connection-level receive window up-front.
        self._connection.increment_flow_control_window(2**24)
        self._events: StreamMatrix = StreamMatrix()
        self._terminated: bool = False
        # Set when a graceful (error_code == 0) GOAWAY was received.
        self._goaway_to_honor: bool = False
        self._max_stream_count: int = (
            self._connection.remote_settings.max_concurrent_streams
        )
        self._max_frame_size: int = self._connection.remote_settings.max_frame_size

    def max_frame_size(self) -> int:
        return self._max_frame_size

    @staticmethod
    def exceptions() -> tuple[type[BaseException], ...]:
        """Exception types the backend may raise, for callers to catch."""
        return jh2.exceptions.ProtocolError, jh2.exceptions.H2Error

    def is_available(self) -> bool:
        if self._terminated:
            return False
        return self._max_stream_count > self._open_stream_count

    @property
    def max_stream_count(self) -> int:
        return self._max_stream_count

    def is_idle(self) -> bool:
        return self._terminated is False and self._open_stream_count == 0

    def has_expired(self) -> bool:
        # A connection honoring a graceful GOAWAY must not take new streams.
        return self._terminated or self._goaway_to_honor

    def get_available_stream_id(self) -> int:
        return self._connection.get_next_available_stream_id()  # type: ignore[no-any-return]

    def submit_close(self, error_code: int = 0) -> None:
        self._connection.close_connection(error_code)

    def submit_headers(
        self, stream_id: int, headers: HeadersType, end_stream: bool = False
    ) -> None:
        """Open a stream by sending HEADERS; widens its receive window."""
        self._connection.send_headers(stream_id, headers, end_stream)
        self._connection.increment_flow_control_window(2**24, stream_id=stream_id)
        self._open_stream_count += 1

    def submit_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        self._connection.send_data(stream_id, data, end_stream)

    def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
        self._connection.reset_stream(stream_id, error_code)

    def next_event(self, stream_id: int | None = None) -> Event | None:
        return self._events.popleft(stream_id=stream_id)

    def has_pending_event(
        self,
        *,
        stream_id: int | None = None,
        excl_event: tuple[type[Event], ...] | None = None,
    ) -> bool:
        return self._events.has(stream_id=stream_id, excl_event=excl_event)

    def _map_events(self, h2_events: list[jh2.events.Event]) -> Iterator[Event]:
        """Translate jh2 events into urllib3-future events.

        Also maintains ``_open_stream_count`` and eagerly retires closed
        streams from the jh2 connection bookkeeping.
        """
        for e in h2_events:
            ev_type = e.__class__

            if ev_type in HEADER_OR_TRAILER_TYPE_SET:
                end_stream = e.stream_ended is not None
                if end_stream:
                    self._open_stream_count -= 1
                    # Retire the stream immediately so jh2 forgets it.
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield HeadersReceived(e.stream_id, e.headers, end_stream=end_stream)
            elif ev_type is jh2.events.DataReceived:
                end_stream = e.stream_ended is not None
                if end_stream:
                    self._open_stream_count -= 1
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                # Credit the flow-control window back for what we consumed.
                self._connection.acknowledge_received_data(
                    e.flow_controlled_length, e.stream_id
                )
                yield DataReceived(e.stream_id, e.data, end_stream=end_stream)
            elif ev_type is jh2.events.InformationalResponseReceived:
                yield EarlyHeadersReceived(
                    e.stream_id,
                    e.headers,
                )
            elif ev_type is jh2.events.StreamReset:
                self._open_stream_count -= 1
                # event StreamEnded may occur before StreamReset
                if e.stream_id in self._connection.streams:
                    stream = self._connection.streams.pop(e.stream_id)
                    self._connection._closed_streams[e.stream_id] = stream.closed_by
                yield StreamResetReceived(e.stream_id, e.error_code)
            elif ev_type is jh2.events.ConnectionTerminated:
                # ConnectionTerminated from h2 means that GOAWAY was received.
                # A server can send GOAWAY for graceful shutdown, where clients
                # do not open new streams, but inflight requests can be completed.
                #
                # Saying "connection was terminated" can be confusing,
                # so we emit an event called "GoawayReceived".
                if e.error_code == 0:
                    self._goaway_to_honor = True
                    yield GoawayReceived(e.last_stream_id, e.error_code)
                else:
                    self._terminated = True
                    yield ConnectionTerminated(e.error_code, None)
            elif ev_type in {
                jh2.events.SettingsAcknowledged,
                jh2.events.RemoteSettingsChanged,
            }:
                yield HandshakeCompleted(alpn_protocol="h2")

    def connection_lost(self) -> None:
        self._connection_terminated()

    def eof_received(self) -> None:
        self._connection_terminated()

    def bytes_received(self, data: bytes) -> None:
        """Feed raw inbound bytes into jh2 and collect resulting events."""
        if not data:
            return

        try:
            h2_events = self._connection.receive_data(data)
        except jh2.exceptions.ProtocolError as e:
            self._connection_terminated(e.error_code, str(e))
        else:
            self._events.extend(self._map_events(h2_events))

            # we want to perpetually mark the connection as "saturated"
            if self._goaway_to_honor:
                self._max_stream_count = self._open_stream_count

            if self._connection.remote_settings.has_update:
                if not self._goaway_to_honor:
                    self._max_stream_count = (
                        self._connection.remote_settings.max_concurrent_streams
                    )
                self._max_frame_size = self._connection.remote_settings.max_frame_size

    def bytes_to_send(self) -> bytes:
        return self._connection.data_to_send()  # type: ignore[no-any-return]

    def _connection_terminated(
        self, error_code: int = 0, message: str | None = None
    ) -> None:
        # Idempotent: only the first termination is recorded.
        if self._terminated:
            return
        error_code = int(error_code)  # Convert h2 IntEnum to an actual int
        self._terminated = True
        self._events.append(ConnectionTerminated(error_code, message))

    def should_wait_remote_flow_control(
        self, stream_id: int, amt: int | None = None
    ) -> bool | None:
        """Whether sending would exceed the peer's flow-control window."""
        flow_remaining_bytes: int = self._connection.local_flow_control_window(
            stream_id
        )

        if amt is None:
            return flow_remaining_bytes == 0

        return amt > flow_remaining_bytes

    def reshelve(self, *events: Event) -> None:
        # Reversed so the first given event ends up first in the queue.
        for ev in reversed(events):
            self._events.appendleft(ev)

    def ping(self) -> None:
        # PING payload must be exactly 8 opaque bytes.
        self._connection.ping(token_bytes(8))
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ._qh3 import HTTP3ProtocolAioQuicImpl
|
||||
|
||||
__all__ = ("HTTP3ProtocolAioQuicImpl",)
|
||||
@@ -0,0 +1,592 @@
|
||||
# Copyright 2022 Akamai Technologies, Inc
|
||||
# Largely rewritten in 2023 for urllib3-future
|
||||
# Copyright 2024 Ahmed Tahri
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from os import environ
|
||||
from random import randint
|
||||
from time import time as monotonic
|
||||
from typing import Any, Iterable, Sequence
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing_extensions import Literal
|
||||
|
||||
from qh3 import (
|
||||
CipherSuite,
|
||||
H3Connection,
|
||||
H3Error,
|
||||
ProtocolError,
|
||||
QuicConfiguration,
|
||||
QuicConnection,
|
||||
QuicConnectionError,
|
||||
QuicFileLogger,
|
||||
SessionTicket,
|
||||
h3_events,
|
||||
quic_events,
|
||||
)
|
||||
from qh3.h3.connection import FrameType
|
||||
from qh3.quic.connection import QuicConnectionState
|
||||
|
||||
from ..._configuration import QuicTLSConfig
|
||||
from ..._stream_matrix import StreamMatrix
|
||||
from ..._typing import AddressType, HeadersType
|
||||
from ...events import (
|
||||
ConnectionTerminated,
|
||||
DataReceived,
|
||||
EarlyHeadersReceived,
|
||||
Event,
|
||||
GoawayReceived,
|
||||
)
|
||||
from ...events import HandshakeCompleted as _HandshakeCompleted
|
||||
from ...events import HeadersReceived, StreamResetReceived
|
||||
from .._protocols import HTTP3Protocol
|
||||
|
||||
|
||||
# QUIC-level (transport) qh3 event types this backend translates into
# urllib3-future events; presumably other QUIC events are filtered out at
# the use site — confirm where this set is consumed.
QUIC_RELEVANT_EVENT_TYPES = {
    quic_events.HandshakeCompleted,
    quic_events.ConnectionTerminated,
    quic_events.StreamReset,
}
|
||||
|
||||
|
||||
class HTTP3ProtocolAioQuicImpl(HTTP3Protocol):
|
||||
implementation: str = "qh3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
remote_address: AddressType,
|
||||
server_name: str,
|
||||
tls_config: QuicTLSConfig,
|
||||
) -> None:
|
||||
keylogfile_path: str | None = environ.get("SSLKEYLOGFILE", None)
|
||||
qlogdir_path: str | None = environ.get("QUICLOGDIR", None)
|
||||
|
||||
self._configuration: QuicConfiguration = QuicConfiguration(
|
||||
is_client=True,
|
||||
verify_mode=ssl.CERT_NONE if tls_config.insecure else ssl.CERT_REQUIRED,
|
||||
cafile=tls_config.cafile,
|
||||
capath=tls_config.capath,
|
||||
cadata=tls_config.cadata,
|
||||
alpn_protocols=["h3"],
|
||||
session_ticket=tls_config.session_ticket,
|
||||
server_name=server_name,
|
||||
hostname_checks_common_name=tls_config.cert_use_common_name,
|
||||
assert_fingerprint=tls_config.cert_fingerprint,
|
||||
verify_hostname=tls_config.verify_hostname,
|
||||
secrets_log_file=open(keylogfile_path, "w") if keylogfile_path else None, # type: ignore[arg-type]
|
||||
quic_logger=QuicFileLogger(qlogdir_path) if qlogdir_path else None,
|
||||
idle_timeout=tls_config.idle_timeout,
|
||||
max_data=2**24,
|
||||
max_stream_data=2**24,
|
||||
)
|
||||
|
||||
if tls_config.ciphers:
|
||||
available_ciphers = {c.name: c for c in CipherSuite}
|
||||
chosen_ciphers: list[CipherSuite] = []
|
||||
|
||||
for cipher in tls_config.ciphers:
|
||||
if "name" in cipher and isinstance(cipher["name"], str):
|
||||
chosen_ciphers.append(
|
||||
available_ciphers[cipher["name"].replace("TLS_", "")]
|
||||
)
|
||||
|
||||
if len(chosen_ciphers) == 0:
|
||||
raise ValueError(
|
||||
f"Unable to find a compatible cipher in '{tls_config.ciphers}' to establish a QUIC connection. "
|
||||
f"QUIC support one of '{['TLS_' + e for e in available_ciphers.keys()]}' only."
|
||||
)
|
||||
|
||||
self._configuration.cipher_suites = chosen_ciphers
|
||||
|
||||
if tls_config.certfile:
|
||||
self._configuration.load_cert_chain(
|
||||
tls_config.certfile,
|
||||
tls_config.keyfile,
|
||||
tls_config.keypassword,
|
||||
)
|
||||
|
||||
self._quic: QuicConnection = QuicConnection(configuration=self._configuration)
|
||||
self._connection_ids: set[bytes] = set()
|
||||
self._remote_address = remote_address
|
||||
self._events: StreamMatrix = StreamMatrix()
|
||||
self._packets: deque[bytes] = deque()
|
||||
self._http: H3Connection | None = None
|
||||
self._terminated: bool = False
|
||||
self._data_in_flight: bool = False
|
||||
self._open_stream_count: int = 0
|
||||
self._total_stream_count: int = 0
|
||||
self._goaway_to_honor: bool = False
|
||||
self._max_stream_count: int = (
|
||||
100 # safe-default, broadly used. (and set by qh3)
|
||||
)
|
||||
self._max_frame_size: int | None = None
|
||||
|
||||
@staticmethod
def exceptions() -> tuple[type[BaseException], ...]:
    """Exception classes that signal a failure of this HTTP/3 backend."""
    kls: tuple[type[BaseException], ...] = (
        ProtocolError,
        H3Error,
        QuicConnectionError,
        AssertionError,
    )
    return kls
|
||||
|
||||
@property
def max_stream_count(self) -> int:
    """Current upper bound on concurrently usable bidirectional streams."""
    limit = self._max_stream_count
    return limit
|
||||
|
||||
def is_available(self) -> bool:
    """Whether at least one more outbound stream may be opened right now."""
    if self._terminated:
        return False
    return self._quic.open_outbound_streams < self._max_stream_count
|
||||
|
||||
def is_idle(self) -> bool:
    """True when the connection is alive but no stream is in progress."""
    alive = not self._terminated
    return alive and self._open_stream_count == 0
|
||||
|
||||
def has_expired(self) -> bool:
    """Advance the QUIC timers and report whether this connection is unusable.

    Returns True when the connection is terminated or a GOAWAY must be
    honored (meaning no new streams may be submitted on it).
    """
    if not self._terminated and not self._goaway_to_honor:
        now = monotonic()
        # Let the QUIC state machine run its idle/loss timers; it may
        # decide to close the connection as a side effect.
        self._quic.handle_timer(now)
        # datagrams_to_send yields (payload, addr) pairs; keep payload only.
        self._packets.extend(
            map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
        )
        if self._quic._state in {
            QuicConnectionState.CLOSING,
            QuicConnectionState.TERMINATED,
        }:
            self._terminated = True
            # NOTE(review): relies on qh3's private `_close_event` attribute
            # to surface the close reason to consumers — confirm against the
            # pinned qh3 version.
            if (
                hasattr(self._quic, "_close_event")
                and self._quic._close_event is not None
            ):
                self._events.extend(self._map_quic_event(self._quic._close_event))
                self._terminated = True  # already set above; harmless repeat
    return self._terminated or self._goaway_to_honor
|
||||
|
||||
@property
def session_ticket(self) -> SessionTicket | None:
    """TLS session ticket for resumption, when the handshake produced one."""
    if self._quic and self._quic.tls:
        return self._quic.tls.session_ticket
    return None
|
||||
|
||||
def get_available_stream_id(self) -> int:
    """Ask the QUIC layer for the next usable stream identifier."""
    stream_id = self._quic.get_next_available_stream_id()
    return stream_id
|
||||
|
||||
def submit_close(self, error_code: int = 0) -> None:
    """Initiate a connection shutdown.

    QUIC (RFC 9000) defines two CONNECTION_CLOSE frame types: 0x1c
    signals errors at the QUIC layer only (or the absence of errors,
    via NO_ERROR), while 0x1d signals an error raised by the
    application protocol running on top of QUIC.
    """
    if error_code:
        frame_type = 0x1D  # application-level error
    else:
        frame_type = 0x1C  # transport-level close / NO_ERROR
    self._quic.close(error_code=error_code, frame_type=frame_type)
|
||||
|
||||
def submit_headers(
    self, stream_id: int, headers: HeadersType, end_stream: bool = False
) -> None:
    """Send the HTTP/3 HEADERS frame opening (or completing) a request."""
    assert self._http is not None
    # A new request stream is now in flight; the total is used by the
    # MAX_STREAMS accounting performed on datagram reception.
    self._open_stream_count = self._open_stream_count + 1
    self._total_stream_count = self._total_stream_count + 1
    self._http.send_headers(stream_id, list(headers), end_stream)
|
||||
|
||||
def submit_data(
    self, stream_id: int, data: bytes, end_stream: bool = False
) -> None:
    """Send a DATA frame carrying request body bytes on the given stream."""
    assert self._http is not None
    self._http.send_data(stream_id, data, end_stream)
    if not end_stream:
        # More body is expected: flag it so flow-control waits apply.
        self._data_in_flight = True
|
||||
|
||||
def submit_stream_reset(self, stream_id: int, error_code: int = 0) -> None:
    """Abruptly terminate a single stream via a RESET_STREAM frame."""
    self._quic.reset_stream(stream_id, error_code)
|
||||
|
||||
def next_event(self, stream_id: int | None = None) -> Event | None:
    """Pop the oldest pending event, optionally scoped to one stream."""
    return self._events.popleft(stream_id=stream_id)
|
||||
|
||||
def has_pending_event(
    self,
    *,
    stream_id: int | None = None,
    excl_event: tuple[type[Event], ...] | None = None,
) -> bool:
    """Whether any event is queued, optionally per-stream and optionally
    ignoring the given event types."""
    pending = self._events.has(stream_id=stream_id, excl_event=excl_event)
    return pending
|
||||
|
||||
@property
def connection_ids(self) -> Sequence[bytes]:
    """Snapshot of the QUIC connection ids tied to this connection."""
    return [*self._connection_ids]
|
||||
|
||||
def connection_lost(self) -> None:
    """Mark the transport as dead and queue a terminal event for consumers."""
    self._terminated = True
    terminal = ConnectionTerminated()
    self._events.append(terminal)
|
||||
|
||||
def bytes_received(self, data: bytes) -> None:
    """Feed one received UDP datagram into the QUIC engine and refresh
    flow-control bookkeeping (stream limit, max frame size)."""
    self._quic.receive_datagram(data, self._remote_address, now=monotonic())
    self._fetch_events()

    # Any inbound traffic means the peer acknowledged our in-flight body.
    if self._data_in_flight:
        self._data_in_flight = False

    # we want to perpetually mark the connection as "saturated"
    if self._goaway_to_honor:
        self._max_stream_count = self._open_stream_count
    else:
        # This section may confuse beginners
        # See RFC 9000 -> 19.11. MAX_STREAMS Frames
        # footer extract:
        # Note that these frames (and the corresponding transport parameters)
        # do not describe the number of streams that can be opened
        # concurrently. The limit includes streams that have been closed as
        # well as those that are open.
        #
        # so, finding that remote_max_streams_bidi is increasing constantly is normal.
        new_stream_limit = (
            self._quic._remote_max_streams_bidi - self._total_stream_count
        )

        if (
            new_stream_limit
            and new_stream_limit != self._max_stream_count
            and new_stream_limit > 0
        ):
            self._max_stream_count = new_stream_limit

    # Track the peer-advertised per-stream flow-control window.
    # NOTE(review): reads qh3's private transport-parameter attribute.
    if (
        self._quic._remote_max_stream_data_bidi_remote
        and self._quic._remote_max_stream_data_bidi_remote
        != self._max_frame_size
    ):
        self._max_frame_size = self._quic._remote_max_stream_data_bidi_remote
|
||||
|
||||
def bytes_to_send(self) -> bytes:
    """Return the next outgoing UDP datagram, or b"" when nothing is pending.

    Lazily performs the initial QUIC connect and creates the H3 layer on
    the first call.
    """
    if not self._packets:
        now = monotonic()

        if self._http is None:
            # First use: kick off the handshake and attach HTTP/3 on top.
            self._quic.connect(self._remote_address, now=now)
            self._http = H3Connection(self._quic)

        # the QUIC state machine returns datagrams (addr, packet)
        # the client never have to worry about the destination
        # unless server yield a preferred address?
        self._packets.extend(
            map(lambda e: e[0], self._quic.datagrams_to_send(now=now))
        )

        if not self._packets:
            return b""

    # it is absolutely crucial to return one at a time
    # because UDP don't support sending more than
    # MTU (to be more precise, lowest MTU in the network path from A (you) to B (server))
    return self._packets.popleft()
|
||||
|
||||
def _fetch_events(self) -> None:
    """Drain the QUIC engine's event queue, translating QUIC and HTTP/3
    events into this backend's own event types."""
    assert self._http is not None

    # next_event returns None when exhausted; iter(..., None) stops there.
    for quic_event in iter(self._quic.next_event, None):
        self._events.extend(self._map_quic_event(quic_event))
        for h3_event in self._http.handle_event(quic_event):
            self._events.extend(self._map_h3_event(h3_event))

    # NOTE(review): reaches into qh3's private `_close_event` to surface
    # a connection close that did not come through the public queue.
    if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
        self._events.extend(self._map_quic_event(self._quic._close_event))
|
||||
|
||||
def _map_quic_event(self, quic_event: quic_events.QuicEvent) -> Iterable[Event]:
    """Translate a qh3 QUIC event into zero or more backend events.

    Also mutates connection state (terminated/goaway flags, open stream
    count) as a side effect of the translation.
    """
    ev_type = quic_event.__class__

    # fastest path execution, most of the time we don't have those
    # 3 event types.
    if ev_type not in QUIC_RELEVANT_EVENT_TYPES:
        return

    if ev_type is quic_events.HandshakeCompleted:
        yield _HandshakeCompleted(quic_event.alpn_protocol)  # type: ignore[attr-defined]
    elif ev_type is quic_events.ConnectionTerminated:
        if quic_event.frame_type == FrameType.GOAWAY.value:  # type: ignore[attr-defined]
            # Graceful shutdown: stop opening streams, let existing finish.
            self._goaway_to_honor = True
            stream_list: list[int] = [
                e for e in self._events._matrix.keys() if e is not None
            ]
            # NOTE(review): stream_list[-1] raises IndexError if no stream
            # event was ever recorded — confirm a GOAWAY cannot arrive
            # before any stream activity.
            yield GoawayReceived(stream_list[-1], quic_event.error_code)  # type: ignore[attr-defined]
        else:
            self._terminated = True
            yield ConnectionTerminated(
                quic_event.error_code,  # type: ignore[attr-defined]
                quic_event.reason_phrase,  # type: ignore[attr-defined]
            )
    elif ev_type is quic_events.StreamReset:
        self._open_stream_count -= 1
        yield StreamResetReceived(quic_event.stream_id, quic_event.error_code)  # type: ignore[attr-defined]
|
||||
|
||||
def _map_h3_event(self, h3_event: h3_events.H3Event) -> Iterable[Event]:
    """Translate a qh3 HTTP/3 event into zero or more backend events,
    decrementing the open stream count when a stream ends."""
    ev_type = h3_event.__class__

    if ev_type is h3_events.HeadersReceived:
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield HeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
            h3_event.stream_ended,  # type: ignore[attr-defined]
        )
    elif ev_type is h3_events.DataReceived:
        if h3_event.stream_ended:  # type: ignore[attr-defined]
            self._open_stream_count -= 1
        yield DataReceived(h3_event.stream_id, h3_event.data, h3_event.stream_ended)  # type: ignore[attr-defined]
    elif ev_type is h3_events.InformationalHeadersReceived:
        # 1xx interim responses (e.g. 103 Early Hints) never end the stream.
        yield EarlyHeadersReceived(
            h3_event.stream_id,  # type: ignore[attr-defined]
            h3_event.headers,  # type: ignore[attr-defined]
        )
|
||||
|
||||
def should_wait_remote_flow_control(
    self, stream_id: int, amt: int | None = None
) -> bool | None:
    """Whether the caller should pause before sending more body data.

    The per-stream arguments are accepted for interface compatibility
    but unused: the decision here is connection-wide.
    """
    return True if self._data_in_flight else False
|
||||
|
||||
@typing.overload
def getissuercert(self, *, binary_form: Literal[True]) -> bytes | None: ...

@typing.overload
def getissuercert(
    self, *, binary_form: Literal[False] = ...
) -> dict[str, Any] | None: ...

def getissuercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any] | None:
    """Return the peer's issuer (intermediate) certificate.

    :param binary_form: When True, return the DER bytes; otherwise a
        dict shaped like ``ssl.SSLSocket.getpeercert`` output.
    :returns: None when the peer sent no issuer certificate.
    :raises ValueError: If the TLS handshake has not completed yet.
    """
    x509_certificate = self._quic.get_peercert()

    if x509_certificate is None:
        raise ValueError("TLS handshake has not been done yet")

    if not self._quic.get_issuercerts():
        return None

    # Only the first (closest) issuer in the chain is exposed.
    x509_certificate = self._quic.get_issuercerts()[0]

    if binary_form:
        return x509_certificate.public_bytes()

    # fix: removed a no-op `datetime.datetime.fromtimestamp(...)` expression
    # whose result was discarded — dead code doing pointless work.

    issuer_info = {
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
    }

    # RFC 4514 short attribute names -> the long names CPython's ssl uses.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }

    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        issuer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )

    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        issuer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )

    return issuer_info
|
||||
|
||||
@typing.overload
def getpeercert(self, *, binary_form: Literal[True]) -> bytes: ...

@typing.overload
def getpeercert(self, *, binary_form: Literal[False] = ...) -> dict[str, Any]: ...

def getpeercert(
    self, *, binary_form: bool = False
) -> bytes | dict[str, typing.Any]:
    """Return the peer leaf certificate.

    Mimics ``ssl.SSLSocket.getpeercert``: DER bytes when *binary_form*
    is True, otherwise a dict of decoded fields (empty list-valued keys
    are dropped, remaining lists become tuples).

    :raises ValueError: If the TLS handshake has not completed yet.
    """
    x509_certificate = self._quic.get_peercert()

    if x509_certificate is None:
        raise ValueError("TLS handshake has not been done yet")

    if binary_form:
        return x509_certificate.public_bytes()

    peer_info = {
        "version": x509_certificate.version + 1,
        "serialNumber": x509_certificate.serial_number.upper(),
        "subject": [],
        "issuer": [],
        "notBefore": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_before, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "notAfter": datetime.datetime.fromtimestamp(
            x509_certificate.not_valid_after, tz=datetime.timezone.utc
        ).strftime("%b %d %H:%M:%S %Y")
        + " UTC",
        "subjectAltName": [],
        "OCSP": [],
        "caIssuers": [],
        "crlDistributionPoints": [],
    }

    # RFC 4514 short attribute names -> the long names CPython's ssl uses.
    _short_name_assoc = {
        "CN": "commonName",
        "L": "localityName",
        "ST": "stateOrProvinceName",
        "O": "organizationName",
        "OU": "organizationalUnitName",
        "C": "countryName",
        "STREET": "streetAddress",
        "DC": "domainComponent",
        "E": "email",
    }

    for raw_oid, rfc4514_attribute_name, value in x509_certificate.subject:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        peer_info["subject"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )

    for raw_oid, rfc4514_attribute_name, value in x509_certificate.issuer:
        if rfc4514_attribute_name not in _short_name_assoc:
            continue
        peer_info["issuer"].append(  # type: ignore[attr-defined]
            (
                (
                    _short_name_assoc[rfc4514_attribute_name],
                    value.decode(),
                ),
            )
        )

    # Alt names arrive as e.g. b"DNSName(example.org)" or an IP form;
    # the value of interest sits between the parentheses.
    for alt_name in x509_certificate.get_subject_alt_names():
        decoded_alt_name = alt_name.decode()
        in_parenthesis = decoded_alt_name[
            decoded_alt_name.index("(") + 1 : decoded_alt_name.index(")")
        ]
        if decoded_alt_name.startswith("DNS"):
            peer_info["subjectAltName"].append(("DNS", in_parenthesis))  # type: ignore[attr-defined]
        else:
            from ....resolver.utils import inet4_ntoa, inet6_ntoa

            # assumes IPv4 renders as hex pairs separated by ':' giving
            # 11 chars (e.g. "7f:00:00:01") — TODO confirm against qh3.
            if len(in_parenthesis) == 11:
                ip_address_decoded = inet4_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            else:
                ip_address_decoded = inet6_ntoa(
                    bytes.fromhex(in_parenthesis.replace(":", ""))
                )
            peer_info["subjectAltName"].append(("IP Address", ip_address_decoded))  # type: ignore[attr-defined]

    peer_info["OCSP"] = []

    for endpoint in x509_certificate.get_ocsp_endpoints():
        decoded_endpoint = endpoint.decode()

        peer_info["OCSP"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )

    peer_info["caIssuers"] = []

    for endpoint in x509_certificate.get_issuer_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["caIssuers"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )

    peer_info["crlDistributionPoints"] = []

    for endpoint in x509_certificate.get_crl_endpoints():
        decoded_endpoint = endpoint.decode()
        peer_info["crlDistributionPoints"].append(  # type: ignore[attr-defined]
            decoded_endpoint[decoded_endpoint.index("(") + 1 : -1]
        )

    # Match the ssl module's shape: lists become tuples, empties vanish.
    pop_keys = []

    for k in peer_info:
        if isinstance(peer_info[k], list):
            peer_info[k] = tuple(peer_info[k])  # type: ignore[arg-type]
            if not peer_info[k]:
                pop_keys.append(k)

    for k in pop_keys:
        peer_info.pop(k)

    return peer_info
|
||||
|
||||
def cipher(self) -> str | None:
    """Negotiated cipher name, formatted like CPython's ssl module.

    :raises ValueError: If the TLS handshake has not completed yet.
    """
    negotiated = self._quic.get_cipher()

    if negotiated is None:
        raise ValueError("TLS handshake has not been done yet")

    return "TLS_" + negotiated.name
|
||||
|
||||
def reshelve(self, *events: Event) -> None:
    """Push events back to the head of the queue, preserving their order."""
    for ev in events[::-1]:
        self._events.appendleft(ev)
|
||||
|
||||
def ping(self) -> None:
    """Send a QUIC PING frame (random uid) to probe connection liveness."""
    self._quic.send_ping(randint(0, 65535))
|
||||
|
||||
def max_frame_size(self) -> int:
    """Peer-advertised per-stream flow-control window, once known.

    :raises NotImplementedError: Before the peer advertised a value.
    """
    size = self._max_frame_size
    if size is None:
        raise NotImplementedError
    return size
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from io import UnsupportedOperation
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
from ._ctypes import load_cert_chain as _ctypes_load_cert_chain
|
||||
from ._shm import load_cert_chain as _shm_load_cert_chain
|
||||
|
||||
# Ordered list of backend strategies for loading an in-memory client
# certificate/key pair into an ssl.SSLContext. Each is tried in turn by
# load_cert_chain() below; the ctypes approach comes first, the
# shared-memory fallback second.
SUPPORTED_METHODS: list[
    typing.Callable[
        [
            ssl.SSLContext,
            bytes | str,
            bytes | str,
            bytes | str | typing.Callable[[], str | bytes] | None,
        ],
        None,
    ]
] = [
    _ctypes_load_cert_chain,
    _shm_load_cert_chain,
]
|
||||
|
||||
|
||||
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
    :raise UnsupportedOperation: If anything goes wrong in the process.
    """
    first_failure: UnsupportedOperation | None = None

    # Try each backend in order; the first one that succeeds wins.
    for candidate in SUPPORTED_METHODS:
        try:
            candidate(
                ctx,
                certdata,
                keydata,
                password,
            )
        except UnsupportedOperation as exc:
            # Remember only the first failure for reporting.
            if first_failure is None:
                first_failure = exc
        else:
            return

    if first_failure is not None:
        raise first_failure

    raise UnsupportedOperation("unable to initialize mTLS using in-memory cert and key")


__all__ = ("load_cert_chain",)
|
||||
@@ -0,0 +1,376 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
from io import UnsupportedOperation
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
|
||||
class _OpenSSL:
    """Access hazardous material from CPython OpenSSL (or compatible SSL) implementation."""

    def __init__(self) -> None:
        """Bind the OpenSSL symbols that back CPython's ssl module via ctypes.

        :raises UnsupportedOperation: On non-CPython interpreters, when the
            private ``ssl._ssl`` module is absent, when the DLLs cannot be
            located (Windows), or when a required symbol is missing.
        """
        import platform

        # id()-as-address and the PySSLContext layout below are CPython-only.
        if platform.python_implementation() != "CPython":
            raise UnsupportedOperation("Only CPython is supported")

        import ssl

        self._name = ssl.OPENSSL_VERSION
        self.ssl = ssl

        # bug seen in Windows + CPython < 3.11
        # where CPython official API for options
        # cast OpenSSL get_options to SIGNED long
        # where we want UNSIGNED long.
        _ssl_options_signed_long_bug = False

        if not hasattr(ssl, "_ssl"):
            raise UnsupportedOperation(
                "Unsupported interpreter due to missing private ssl module"
            )

        if platform.system() == "Windows":
            # possible search locations
            candidates = {
                os.path.dirname(sys.executable),
                os.path.join(sys.prefix, "DLLs"),
                sys.prefix,
            }

            if hasattr(ssl._ssl, "__file__"):
                candidates.add(os.path.dirname(ssl._ssl.__file__))

            _ssl_options_signed_long_bug = sys.version_info < (3, 11)

            ssl_potential_match = None
            crypto_potential_match = None

            for d in candidates:
                if not os.path.exists(d):
                    continue

                for filename in os.listdir(d):
                    if ssl_potential_match is None:
                        if filename.startswith("libssl") and filename.endswith(".dll"):
                            ssl_potential_match = os.path.join(d, filename)

                    if crypto_potential_match is None:
                        if filename.startswith("libcrypto") and filename.endswith(
                            ".dll"
                        ):
                            crypto_potential_match = os.path.join(d, filename)

                if crypto_potential_match and ssl_potential_match:
                    break

            if not ssl_potential_match or not crypto_potential_match:
                raise UnsupportedOperation(
                    "Could not locate OpenSSL DLLs next to Python; "
                    "check your /DLLs folder or your PATH."
                )

            self._ssl = ctypes.CDLL(ssl_potential_match)
            self._crypto = ctypes.CDLL(crypto_potential_match)
        else:
            # that's the most common path
            # ssl built in module already loaded both crypto and ssl
            # symbols.
            if hasattr(ssl._ssl, "__file__"):
                self._ssl = ctypes.CDLL(ssl._ssl.__file__)
            else:
                # _ssl is statically linked into the interpreter
                # (e.g. python-build-standalone via uv). OpenSSL symbols
                # are in the main process image; ctypes.CDLL(None) exposes them.
                # see https://github.com/jawah/urllib3.future/issues/325 for more
                # details.
                self._ssl = ctypes.CDLL(None)
            self._crypto = self._ssl

        # we want to ensure a minimal set of symbols
        # are present. CPython should have at least:
        for required_symbol in [
            "SSL_CTX_use_certificate",
            "SSL_CTX_check_private_key",
            "SSL_CTX_use_PrivateKey",
        ]:
            if not hasattr(self._ssl, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libssl) {required_symbol} is not present."
                )

        for required_symbol in [
            "BIO_free",
            "BIO_new_mem_buf",
            "PEM_read_bio_X509",
            "PEM_read_bio_PrivateKey",
            "ERR_get_error",
            "ERR_error_string",
        ]:
            if not hasattr(self._crypto, required_symbol):
                raise UnsupportedOperation(
                    f"Python interpreter built against '{self._name}' is unsupported. (libcrypto) {required_symbol} is not present."
                )

        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_certificate = self._ssl.SSL_CTX_use_certificate
        self.SSL_CTX_use_certificate.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_certificate.restype = ctypes.c_int

        self.SSL_CTX_check_private_key = self._ssl.SSL_CTX_check_private_key
        self.SSL_CTX_check_private_key.argtypes = [ctypes.c_void_p]
        self.SSL_CTX_check_private_key.restype = ctypes.c_int

        # https://docs.openssl.org/3.0/man3/BIO_new/
        self.BIO_free = self._crypto.BIO_free
        self.BIO_free.argtypes = [ctypes.c_void_p]
        self.BIO_free.restype = None

        self.BIO_new_mem_buf = self._crypto.BIO_new_mem_buf
        self.BIO_new_mem_buf.argtypes = [ctypes.c_void_p, ctypes.c_int]
        self.BIO_new_mem_buf.restype = ctypes.c_void_p

        # https://docs.openssl.org/3.0/man3/PEM_read_bio_PrivateKey/
        self.PEM_read_bio_X509 = self._crypto.PEM_read_bio_X509
        self.PEM_read_bio_X509.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_X509.restype = ctypes.c_void_p

        self.PEM_read_bio_PrivateKey = self._crypto.PEM_read_bio_PrivateKey
        self.PEM_read_bio_PrivateKey.argtypes = [
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_void_p,
        ]
        self.PEM_read_bio_PrivateKey.restype = ctypes.c_void_p

        # https://docs.openssl.org/3.0/man3/SSL_CTX_use_certificate/
        self.SSL_CTX_use_PrivateKey = self._ssl.SSL_CTX_use_PrivateKey
        self.SSL_CTX_use_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
        self.SSL_CTX_use_PrivateKey.restype = ctypes.c_int

        self.ERR_get_error = self._crypto.ERR_get_error
        self.ERR_get_error.argtypes = []
        self.ERR_get_error.restype = ctypes.c_ulong

        self.ERR_error_string = self._crypto.ERR_error_string
        self.ERR_error_string.argtypes = [ctypes.c_ulong, ctypes.c_char_p]
        self.ERR_error_string.restype = ctypes.c_char_p

        if hasattr(self._ssl, "SSL_CTX_get_options"):
            self.SSL_CTX_get_options = self._ssl.SSL_CTX_get_options
            self.SSL_CTX_get_options.argtypes = [ctypes.c_void_p]
            self.SSL_CTX_get_options.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )  # OpenSSL's options are long
        elif hasattr(self._ssl, "SSL_CTX_ctrl"):
            # some old build inline SSL_CTX_get_options (mere C define)
            # define SSL_CTX_get_options(ctx) SSL_CTX_ctrl((ctx),SSL_CTRL_OPTIONS,0,NULL)
            # define SSL_CTRL_OPTIONS 32

            self.SSL_CTX_ctrl = self._ssl.SSL_CTX_ctrl
            self.SSL_CTX_ctrl.argtypes = [
                ctypes.c_void_p,
                ctypes.c_int,
                ctypes.c_int,
                ctypes.c_void_p,
            ]
            self.SSL_CTX_ctrl.restype = (
                ctypes.c_ulong if not _ssl_options_signed_long_bug else ctypes.c_long
            )

            self.SSL_CTX_get_options = lambda ctx: self.SSL_CTX_ctrl(  # type: ignore[assignment]
                ctx, 32, 0, None
            )
        else:
            raise UnsupportedOperation()

    def pull_error(self) -> typing.NoReturn:
        """Raise ssl.SSLError with the oldest queued OpenSSL error string."""
        raise self.ssl.SSLError(
            self.ERR_error_string(
                self.ERR_get_error(), ctypes.create_string_buffer(256)
            ).decode()
        )
|
||||
|
||||
|
||||
# Free-threaded (PEP 703) builds insert extra header words before the
# object payload; the count differs by platform.
_IS_GIL_DISABLED = hasattr(sys, "_is_gil_enabled") and sys._is_gil_enabled() is False
_IS_LINUX = sys.platform == "linux"
# assumes 1 extra pointer-sized slot on Linux, 2 elsewhere — TODO confirm
# against the targeted CPython free-threaded layouts.
_FT_HEAD_ADDITIONAL_OFFSET = 1 if _IS_LINUX else 2

_head_extra_fields = []

if sys.flags.debug:
    # In debug builds (_POSIX_C_SOURCE or Py_DEBUG is defined), PyObject_HEAD
    # is preceded by _PyObject_HEAD_EXTRA, which typically consists of
    # two pointers (_ob_next, _ob_prev).
    _head_extra_fields = [("_ob_next", ctypes.c_void_p), ("_ob_prev", ctypes.c_void_p)]


# Define the PySSLContext C structure using ctypes.
# This definition assumes that 'SSL_CTX *ctx' is the first member
# immediately following PyObject_HEAD. This has been observed to be
# the case in various CPython versions (e.g., 3.7 through 3.14 so far).
#
# CPython's Modules/_ssl.c (simplified):
# typedef struct {
#     PyObject_HEAD  // Expands to _PyObject_HEAD_EXTRA (if debug) + ob_refcnt + ob_type
#     SSL_CTX *ctx;
#     // ... other members ...
# } PySSLContextObject;
#
class PySSLContextStruct(ctypes.Structure):
    _fields_ = (
        _head_extra_fields  # type: ignore[assignment]
        + [
            ("ob_refcnt", ctypes.c_ssize_t),  # Py_ssize_t ob_refcnt;
            ("ob_type", ctypes.c_void_p),  # PyTypeObject *ob_type;
        ]
        + (
            [(f"_ob_ft{i}", ctypes.c_void_p) for i in range(_FT_HEAD_ADDITIONAL_OFFSET)]
            if _IS_GIL_DISABLED
            else []
        )
        + [
            ("ssl_ctx", ctypes.c_void_p),  # SSL_CTX *ctx; (this is the pointer we want)
            # If there were other C members between ob_type and ssl_ctx,
            # they would need to be defined here with their correct types and padding.
        ]
    )
|
||||
|
||||
|
||||
def _split_client_cert(data: bytes) -> list[bytes]:
|
||||
line_ending = b"\n" if b"-----\r\n" not in data else b"\r\n"
|
||||
boundary = b"-----END CERTIFICATE-----" + line_ending
|
||||
|
||||
certificates = []
|
||||
|
||||
for chunk in data.split(boundary):
|
||||
if chunk:
|
||||
start_marker = chunk.find(b"-----BEGIN CERTIFICATE-----" + line_ending)
|
||||
if start_marker == -1:
|
||||
break
|
||||
pem_reconstructed = b"".join([chunk[start_marker:], boundary])
|
||||
certificates.append(pem_reconstructed)
|
||||
|
||||
return certificates
|
||||
|
||||
|
||||
def load_cert_chain(
    ctx: ssl.SSLContext,
    certdata: bytes | str,
    keydata: bytes | str,
    password: bytes | str | typing.Callable[[], str | bytes] | None = None,
) -> None:
    """
    Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.

    Reaches into the CPython ``PySSLContext`` object to obtain the raw
    ``SSL_CTX*`` and loads the PEM cert/key directly through OpenSSL.

    :raise UnsupportedOperation: If anything goes wrong in the process.
    :raise MemoryError: If an OpenSSL BIO could not be allocated.
    """
    lib = _OpenSSL()

    # Get the memory address of the Python ssl.SSLContext object.
    # id() returns the address of the PyObject.
    addr = id(ctx)

    # Cast this memory address to a pointer to our defined PySSLContextStruct.
    ptr_to_pysslcontext_struct = ctypes.cast(addr, ctypes.POINTER(PySSLContextStruct))

    # Access the 'ssl_ctx' field from the structure. This field holds the
    # actual SSL_CTX* C pointer value.
    ssl_ctx_address = ptr_to_pysslcontext_struct.contents.ssl_ctx

    # We want to ensure we got the right pointer address
    # the safest way to achieve that is to retrieve options
    # and compare it with the official ctx property.
    # NOTE(review): SSL_CTX_get_options is always bound (the _OpenSSL
    # constructor raised otherwise), so this guard looks always-true.
    if lib.SSL_CTX_get_options is not None:
        bypass_options = lib.SSL_CTX_get_options(ssl_ctx_address)
        expected_options = int(ctx.options)

        if bypass_options != expected_options:
            raise UnsupportedOperation(
                f"CPython internal SSL_CTX changed! Cannot pursue safely. Expected = {expected_options:x} Actual = {bypass_options:x}"
            )

    # normalize inputs
    if isinstance(certdata, str):
        certdata = certdata.encode()
    if isinstance(keydata, str):
        keydata = keydata.encode()

    client_chain = _split_client_cert(certdata)

    leaf_certificate = client_chain[0]

    # Use a BIO to read the client certificate
    # only the leaf certificate is supported here.
    cert_bio = lib.BIO_new_mem_buf(leaf_certificate, len(leaf_certificate))

    if not cert_bio:
        raise MemoryError("Unable to allocate memory to load the client certificate")

    # Use a BIO to load the key in-memory
    key_bio = lib.BIO_new_mem_buf(keydata, len(keydata))

    if not key_bio:
        raise MemoryError("Unable to allocate memory to load the client key")

    # prepare the password
    if callable(password):
        password = password()

    if isinstance(password, str):
        password = password.encode()

    assert password is None or isinstance(password, bytes)

    # the allocated X509 obj MUST NOT be freed by ourselves
    # OpenSSL internals will free it once not needed.
    cert = lib.PEM_read_bio_X509(cert_bio, None, None, None)

    # we do own the BIO, once the X509 leaf is instantiated, no need
    # to keep it afterward.
    lib.BIO_free(cert_bio)

    if not cert:
        lib.pull_error()

    pkey = lib.PEM_read_bio_PrivateKey(key_bio, None, None, password)

    lib.BIO_free(key_bio)

    if not pkey:
        lib.pull_error()

    # All three calls return 1 on success per the OpenSSL API.
    if lib.SSL_CTX_use_certificate(ssl_ctx_address, cert) != 1:
        lib.pull_error()

    if lib.SSL_CTX_use_PrivateKey(ssl_ctx_address, pkey) != 1:
        lib.pull_error()

    if lib.SSL_CTX_check_private_key(ssl_ctx_address) != 1:
        lib.pull_error()

    # Unfortunately, most of the time
    # SSL_CTX_add_extra_chain_cert is unavailable
    # in the final CPython build.
    # According to OpenSSL latest docs: "The engine
    # will attempt to build the required chain for the CA store"
    # It's not going to be used as a trust anchor! (i.e. not self-signed)
    # "If no chain is specified, the library will try to complete the
    # chain from the available CA certificates in the trusted
    # CA storage, see SSL_CTX_load_verify_locations(3)."
    # see: https://docs.openssl.org/master/man3/SSL_CTX_add_extra_chain_cert/#notes
    if len(client_chain) > 1:
        ctx.load_verify_locations(cadata=(b"\n".join(client_chain[1:])).decode())


__all__ = ("load_cert_chain",)
|
||||
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import stat
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from hashlib import sha256
|
||||
from io import UnsupportedOperation
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
|
||||
def load_cert_chain(
|
||||
ctx: ssl.SSLContext,
|
||||
certdata: str | bytes,
|
||||
keydata: str | bytes | None = None,
|
||||
password: typing.Callable[[], str | bytes] | str | bytes | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Unique workaround the known limitation of CPython inability to initialize the mTLS context without files.
|
||||
Only supported on Linux, FreeBSD, and OpenBSD.
|
||||
:raise UnsupportedOperation: If anything goes wrong in the process.
|
||||
"""
|
||||
if (
|
||||
sys.platform != "linux"
|
||||
and sys.platform.startswith("freebsd") is False
|
||||
and sys.platform.startswith("openbsd") is False
|
||||
):
|
||||
raise UnsupportedOperation(
|
||||
f"Unable to provide support for in-memory client certificate: Unsupported platform {sys.platform}"
|
||||
)
|
||||
|
||||
unique_name: str = f"{sha256(secrets.token_bytes(32)).hexdigest()}.pem"
|
||||
|
||||
if isinstance(certdata, bytes):
|
||||
certdata = certdata.decode("ascii")
|
||||
|
||||
if keydata is not None:
|
||||
if isinstance(keydata, bytes):
|
||||
keydata = keydata.decode("ascii")
|
||||
|
||||
if hasattr(os, "memfd_create"):
|
||||
fd = os.memfd_create(unique_name, os.MFD_CLOEXEC)
|
||||
else:
|
||||
# this branch patch is for CPython <3.8 and PyPy 3.7+
|
||||
from ctypes import c_int, c_ushort, cdll, create_string_buffer, get_errno, util
|
||||
|
||||
loc = util.find_library("rt") or util.find_library("c")
|
||||
|
||||
if not loc:
|
||||
raise UnsupportedOperation(
|
||||
"Unable to provide support for in-memory client certificate: libc or librt not found."
|
||||
)
|
||||
|
||||
lib = cdll.LoadLibrary(loc)
|
||||
|
||||
_shm_open = lib.shm_open
|
||||
# _shm_unlink = lib.shm_unlink
|
||||
|
||||
buf_name = create_string_buffer(unique_name.encode())
|
||||
|
||||
try:
|
||||
fd = _shm_open(
|
||||
buf_name,
|
||||
c_int(os.O_RDWR | os.O_CREAT),
|
||||
c_ushort(stat.S_IRUSR | stat.S_IWUSR),
|
||||
)
|
||||
except SystemError as e:
|
||||
raise UnsupportedOperation(
|
||||
f"Unable to provide support for in-memory client certificate: {e}"
|
||||
)
|
||||
|
||||
if fd == -1:
|
||||
raise UnsupportedOperation(
|
||||
f"Unable to provide support for in-memory client certificate: {os.strerror(get_errno())}"
|
||||
)
|
||||
|
||||
# Linux 3.17+
|
||||
path = f"/proc/self/fd/{fd}"
|
||||
|
||||
# Alt-path
|
||||
shm_path = f"/dev/shm/{unique_name}"
|
||||
|
||||
if os.path.exists(path) is False:
|
||||
if os.path.exists(shm_path):
|
||||
path = shm_path
|
||||
else:
|
||||
os.fdopen(fd).close()
|
||||
|
||||
raise UnsupportedOperation(
|
||||
"Unable to provide support for in-memory client certificate: no virtual patch available?"
|
||||
)
|
||||
|
||||
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
with open(path, "w") as fp:
|
||||
fp.write(certdata)
|
||||
|
||||
if keydata:
|
||||
fp.write(keydata)
|
||||
|
||||
path = fp.name
|
||||
|
||||
ctx.load_cert_chain(path, password=password)
|
||||
|
||||
# we shall start cleaning remnants
|
||||
os.fdopen(fd).close()
|
||||
|
||||
if os.path.exists(shm_path):
|
||||
os.unlink(shm_path)
|
||||
|
||||
if os.path.exists(path) or os.path.exists(shm_path):
|
||||
warnings.warn(
|
||||
"In-memory client certificate: The kernel leaked a file descriptor outside of its expected lifetime.",
|
||||
ResourceWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("load_cert_chain",)
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"'urllib3.contrib.pyopenssl' module has been removed in urllib3.future due to incompatibilities "
|
||||
"with our QUIC integration. While the import proceed without error for your convenience, it is rendered "
|
||||
"completely ineffective. Were you looking for in-memory client certificate? "
|
||||
"See https://urllib3future.readthedocs.io/en/latest/advanced-usage.html#in-memory-client-mtls-certificate"
|
||||
),
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
import OpenSSL.SSL # type: ignore # noqa
|
||||
|
||||
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
|
||||
|
||||
|
||||
def inject_into_urllib3() -> None:
|
||||
"""Kept for BC-purposes."""
|
||||
...
|
||||
|
||||
|
||||
def extract_from_urllib3() -> None:
|
||||
"""Kept for BC-purposes."""
|
||||
...
|
||||
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .factories import ResolverDescription, ResolverFactory
|
||||
from .protocols import BaseResolver, ManyResolver, ProtocolResolver
|
||||
|
||||
__all__ = (
|
||||
"ResolverFactory",
|
||||
"ProtocolResolver",
|
||||
"BaseResolver",
|
||||
"ManyResolver",
|
||||
"ResolverDescription",
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .factories import AsyncResolverDescription, AsyncResolverFactory
|
||||
from .protocols import AsyncBaseResolver, AsyncManyResolver
|
||||
|
||||
__all__ = (
|
||||
"AsyncResolverDescription",
|
||||
"AsyncResolverFactory",
|
||||
"AsyncBaseResolver",
|
||||
"AsyncManyResolver",
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._urllib3 import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
HTTPSResolver,
|
||||
NextDNSResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"HTTPSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"OpenDNSResolver",
|
||||
"Quad9Resolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,656 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from asyncio import as_completed
|
||||
from base64 import b64encode
|
||||
|
||||
from ....._async.connectionpool import AsyncHTTPSConnectionPool
|
||||
from ....._async.response import AsyncHTTPResponse
|
||||
from ....._collections import HTTPHeaderDict
|
||||
from .....backend import ConnectionInfo, HttpVersion
|
||||
from .....util.url import parse_url
|
||||
from ...protocols import (
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class HTTPSResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Advanced DNS over HTTPS resolver.
|
||||
No common ground emerged from IETF w/ JSON. Following Google’s DNS over HTTPS schematics that is
|
||||
also implemented at Cloudflare.
|
||||
|
||||
Support RFC 8484 without JSON. Disabled by default.
|
||||
"""
|
||||
|
||||
implementation = "urllib3"
|
||||
protocol = ProtocolResolver.DOH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
super().__init__(server, port or 443, *patterns, **kwargs)
|
||||
|
||||
self._path: str = "/resolve"
|
||||
|
||||
if "path" in kwargs:
|
||||
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
|
||||
self._path = kwargs["path"]
|
||||
kwargs.pop("path")
|
||||
|
||||
self._rfc8484: bool = False
|
||||
|
||||
if "rfc8484" in kwargs:
|
||||
if kwargs["rfc8484"]:
|
||||
self._rfc8484 = True
|
||||
kwargs.pop("rfc8484")
|
||||
|
||||
assert self._server is not None
|
||||
|
||||
if "source_address" in kwargs:
|
||||
if isinstance(kwargs["source_address"], str):
|
||||
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
|
||||
|
||||
if bind_ip and bind_port.isdigit():
|
||||
kwargs["source_address"] = (
|
||||
bind_ip,
|
||||
int(bind_port),
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
|
||||
if "proxy" in kwargs:
|
||||
kwargs["_proxy"] = parse_url(kwargs["proxy"])
|
||||
kwargs.pop("proxy")
|
||||
|
||||
if "maxsize" not in kwargs:
|
||||
kwargs["maxsize"] = 10
|
||||
|
||||
if "proxy_headers" in kwargs and "_proxy" in kwargs:
|
||||
proxy_headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["proxy_headers"], list):
|
||||
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
|
||||
|
||||
for item in kwargs["proxy_headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
proxy_headers.add(k, v)
|
||||
|
||||
kwargs["_proxy_headers"] = proxy_headers
|
||||
|
||||
if "headers" in kwargs:
|
||||
headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["headers"], list):
|
||||
kwargs["headers"] = [kwargs["headers"]]
|
||||
|
||||
for item in kwargs["headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
headers.add(k, v)
|
||||
|
||||
kwargs["headers"] = headers
|
||||
|
||||
if "disabled_svn" in kwargs:
|
||||
if not isinstance(kwargs["disabled_svn"], list):
|
||||
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
|
||||
|
||||
disabled_svn = set()
|
||||
|
||||
for svn in kwargs["disabled_svn"]:
|
||||
svn = svn.lower()
|
||||
|
||||
if svn == "h11":
|
||||
disabled_svn.add(HttpVersion.h11)
|
||||
elif svn == "h2":
|
||||
disabled_svn.add(HttpVersion.h2)
|
||||
elif svn == "h3":
|
||||
disabled_svn.add(HttpVersion.h3)
|
||||
|
||||
kwargs["disabled_svn"] = disabled_svn
|
||||
|
||||
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
|
||||
self._connection_callback: (
|
||||
typing.Callable[[ConnectionInfo], None] | None
|
||||
) = kwargs["on_post_connection"]
|
||||
kwargs.pop("on_post_connection")
|
||||
else:
|
||||
self._connection_callback = None
|
||||
|
||||
self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
await self._pool.close()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._pool.pool is not None
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' from a HTTPSResolver"
|
||||
)
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
promises = []
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "1"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.A, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "28"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.AAAA, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "65"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.HTTPS, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
tasks = []
|
||||
responses = []
|
||||
|
||||
for promise in promises:
|
||||
# already resolved
|
||||
if isinstance(promise, AsyncHTTPResponse):
|
||||
responses.append(promise)
|
||||
continue
|
||||
tasks.append(self._pool.get_response(promise=promise))
|
||||
|
||||
if tasks:
|
||||
for waiting_promise_coro in as_completed(tasks):
|
||||
responses.append(await waiting_promise_coro) # type: ignore[arg-type]
|
||||
|
||||
results: list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
] = []
|
||||
|
||||
for response in responses:
|
||||
if response.status >= 300:
|
||||
raise socket.gaierror(
|
||||
f"DNS over HTTPS was unsuccessful, server response status {response.status}."
|
||||
)
|
||||
|
||||
if not self._rfc8484:
|
||||
payload = await response.json()
|
||||
|
||||
assert "Status" in payload and isinstance(payload["Status"], int)
|
||||
|
||||
if payload["Status"] != 0:
|
||||
msg = (
|
||||
payload["Comment"]
|
||||
if "Comment" in payload
|
||||
else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
|
||||
)
|
||||
|
||||
if isinstance(msg, list):
|
||||
msg = ", ".join(msg)
|
||||
|
||||
raise socket.gaierror(msg)
|
||||
|
||||
assert "Question" in payload and isinstance(payload["Question"], list)
|
||||
|
||||
if "Answer" not in payload:
|
||||
continue
|
||||
|
||||
assert isinstance(payload["Answer"], list)
|
||||
|
||||
for answer in payload["Answer"]:
|
||||
if answer["type"] not in [1, 28, 65]:
|
||||
continue
|
||||
|
||||
assert "data" in answer
|
||||
assert isinstance(answer["data"], str)
|
||||
|
||||
# DNS RR/HTTPS
|
||||
if answer["type"] == 65:
|
||||
# "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
|
||||
# or..
|
||||
# "1 . alpn=h2,h3"
|
||||
rr: str = answer["data"]
|
||||
|
||||
if rr.startswith("\\#"): # it means, raw, bytes.
|
||||
rr = "".join(rr[2:].split(" ")[2:])
|
||||
|
||||
try:
|
||||
raw_record = bytes.fromhex(rr)
|
||||
except ValueError:
|
||||
raw_record = b""
|
||||
|
||||
if not raw_record:
|
||||
continue
|
||||
|
||||
https_record = parse_https_rdata(raw_record)
|
||||
|
||||
if "h3" not in https_record["alpn"]:
|
||||
continue
|
||||
|
||||
remote_preemptive_quic_rr = True
|
||||
else:
|
||||
rr_decode: dict[str, str] = dict(
|
||||
tuple(_.lower().split("=", 1)) # type: ignore[misc]
|
||||
for _ in rr.split(" ")
|
||||
if "=" in _
|
||||
)
|
||||
|
||||
if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
|
||||
continue
|
||||
|
||||
remote_preemptive_quic_rr = True
|
||||
|
||||
if "ipv4hint" in rr_decode and family in [
|
||||
socket.AF_UNSPEC,
|
||||
socket.AF_INET,
|
||||
]:
|
||||
for ipv4 in rr_decode["ipv4hint"].split(","):
|
||||
results.append(
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
ipv4,
|
||||
port,
|
||||
),
|
||||
)
|
||||
)
|
||||
if "ipv6hint" in rr_decode and family in [
|
||||
socket.AF_UNSPEC,
|
||||
socket.AF_INET6,
|
||||
]:
|
||||
for ipv6 in rr_decode["ipv6hint"].split(","):
|
||||
results.append(
|
||||
(
|
||||
socket.AF_INET6,
|
||||
socket.SOCK_DGRAM,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
ipv6,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
|
||||
)
|
||||
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
answer["data"],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
answer["data"],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_resp = DomainNameServerReturn(await response.data)
|
||||
|
||||
for record in dns_resp.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
|
||||
dst_addr = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results: list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
] = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
if "rfc8484" in kwargs:
|
||||
if kwargs["rfc8484"]:
|
||||
kwargs["path"] = "/dns-query"
|
||||
super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query"})
|
||||
super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "opendns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "nextdns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
|
||||
except ImportError:
|
||||
QUICResolver = None # type: ignore
|
||||
AdGuardResolver = None # type: ignore
|
||||
NextDNSResolver = None # type: ignore
|
||||
|
||||
|
||||
__all__ = (
|
||||
"QUICResolver",
|
||||
"AdGuardResolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,557 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from ssl import SSLError
|
||||
from time import time as monotonic
|
||||
|
||||
from qh3.quic.configuration import QuicConfiguration
|
||||
from qh3.quic.connection import QuicConnection
|
||||
from qh3.quic.events import (
|
||||
ConnectionTerminated,
|
||||
HandshakeCompleted,
|
||||
QuicEvent,
|
||||
StopSendingReceived,
|
||||
StreamDataReceived,
|
||||
StreamReset,
|
||||
)
|
||||
|
||||
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
if IS_FIPS:
|
||||
raise ImportError(
|
||||
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
|
||||
)
|
||||
|
||||
|
||||
class QUICResolver(PlainResolver):
|
||||
protocol = ProtocolResolver.DOQ
|
||||
implementation = "qh3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
):
|
||||
super().__init__(server, port or 853, *patterns, **kwargs)
|
||||
|
||||
# qh3 load_default_certs seems off. need to investigate.
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
kwargs["ca_cert_data"] = []
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
|
||||
try:
|
||||
ctx.load_default_certs()
|
||||
|
||||
for der in ctx.get_ca_certs(binary_form=True):
|
||||
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
|
||||
|
||||
if kwargs["ca_cert_data"]:
|
||||
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
|
||||
else:
|
||||
del kwargs["ca_cert_data"]
|
||||
except (AttributeError, ValueError, OSError):
|
||||
del kwargs["ca_cert_data"]
|
||||
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
if (
|
||||
"cert_reqs" not in kwargs
|
||||
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
|
||||
):
|
||||
raise ssl.SSLError(
|
||||
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
|
||||
"Add ?cert_reqs=0 to disable certificate checks."
|
||||
)
|
||||
|
||||
configuration = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["doq"],
|
||||
server_name=self._server
|
||||
if "server_hostname" not in kwargs
|
||||
else kwargs["server_hostname"],
|
||||
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
|
||||
if "cert_reqs" in kwargs
|
||||
else ssl.CERT_REQUIRED,
|
||||
cadata=kwargs["ca_cert_data"].encode()
|
||||
if "ca_cert_data" in kwargs
|
||||
else None,
|
||||
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
|
||||
idle_timeout=300.0,
|
||||
)
|
||||
|
||||
if "cert_file" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_file"],
|
||||
kwargs["key_file"] if "key_file" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
elif "cert_data" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_data"],
|
||||
kwargs["key_data"] if "key_data" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
|
||||
self._quic = QuicConnection(configuration=configuration)
|
||||
|
||||
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
|
||||
self._connect_attempt: asyncio.Event = asyncio.Event()
|
||||
self._handshake_event: asyncio.Event = asyncio.Event()
|
||||
|
||||
self._terminated: bool = False
|
||||
self._should_disconnect: bool = False
|
||||
|
||||
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
|
||||
self._rfc1035_prefix_mandated = True
|
||||
|
||||
self._unconsumed: deque[DomainNameServerReturn] = deque()
|
||||
self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
if (
|
||||
not self._terminated
|
||||
and self._socket is not None
|
||||
and not self._socket.should_connect()
|
||||
):
|
||||
self._quic.close()
|
||||
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(monotonic())
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
self._socket.close()
|
||||
await self._socket.wait_for_close()
|
||||
self._terminated = True
|
||||
if self._socket is None or self._socket.should_connect():
|
||||
self._terminated = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._terminated:
|
||||
return False
|
||||
if self._socket is None or self._socket.should_connect():
|
||||
return True
|
||||
self._quic.handle_timer(monotonic())
|
||||
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
|
||||
self._terminated = True
|
||||
return not self._terminated
|
||||
|
||||
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* via DNS over QUIC, mimicking ``socket.getaddrinfo``.

        Returns getaddrinfo-shaped 5-tuples. When *quic_upgrade_via_dns_rr* is
        true, an HTTPS resource record advertising ``h3`` prepends UDP entries
        so callers may attempt an HTTP/3 upgrade.

        Raises ``socket.gaierror`` on resolution failure and ``SSLError`` (via
        ``__exchange_until``) on certificate verification failure.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals short-circuit: no query goes on the wire, only a
        # family-compatibility check is performed.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,  # IPPROTO_TCP
                    "",
                    (host, port),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,  # IPPROTO_UDP
                    "",
                    (host, port, 0, 0),
                )
            ]

        validate_length_of(host)

        # First task to arrive performs connect + TLS/QUIC handshake; every
        # other concurrent task waits on the handshake event below.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()
            assert self.server is not None
            self._quic.connect((self._server, self._port), monotonic())
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=None,
                socket_kind=self._socket_type,
            )
            await self.__exchange_until(HandshakeCompleted, receive_first=False)
            self._handshake_event.set()
        else:
            await self._handshake_event.wait()

        assert self._socket is not None

        remote_preemptive_quic_rr = False

        # Asking for a UDP/h3 upgrade makes no sense when the caller already
        # wants datagram sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # tbq: query types "to be queried" given the requested address family.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)
        # NOTE(review): open_streams is collected but never read afterwards.
        open_streams = []

        # Each DNS query rides its own (immediately-FIN'd) QUIC stream.
        for q in queries:
            payload = bytes(q)

            self._pending.append(q)

            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)

            stream_id = self._quic.get_next_available_stream_id()
            self._quic.send_stream_data(stream_id, payload, True)

            open_streams.append(stream_id)

        for dg in self._quic.datagrams_to_send(monotonic()):
            await self._socket.sendall(dg[0])

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            # Only one task reads the shared connection at a time.
            await self._read_semaphore.acquire()
            # Another task may have already read a response addressed to us.
            if self._unconsumed:
                dns_resp = None
                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break
                if dns_resp:
                    self._unconsumed.remove(dns_resp)
                    self._pending.remove(query)
                    self._read_semaphore.release()
                    continue

            try:
                events: list[StreamDataReceived] = await self.__exchange_until(  # type: ignore[assignment]
                    StreamDataReceived,
                    receive_first=True,
                    event_type_collectable=(StreamDataReceived,),
                    respect_end_stream_signal=False,
                )

                payload = b"".join([e.data for e in events])

                # The 2-byte RFC 1035 length prefix tells us whether a message
                # is still incomplete; keep reading until it is whole.
                while rfc1035_should_read(payload):
                    events.extend(
                        await self.__exchange_until(  # type: ignore[arg-type]
                            StreamDataReceived,
                            receive_first=True,
                            event_type_collectable=(StreamDataReceived,),
                            respect_end_stream_signal=False,
                        )
                    )
                    payload = b"".join([e.data for e in events])
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            if not payload:
                continue

            #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
            fragments = rfc1035_unpack(payload)

            for fragment in fragments:
                dns_resp = DomainNameServerReturn(fragment)

                if any(dns_resp.id == _.id for _ in queries):
                    responses.append(dns_resp)

                    query_tbr: DomainNameServerQuery | None = None

                    for query_tbr in self._pending:
                        if query_tbr.id == dns_resp.id:
                            break
                    if query_tbr:
                        self._pending.remove(query_tbr)
                else:
                    # Response belongs to a concurrent lookup; park it for them.
                    self._unconsumed.append(dns_resp)

        # The peer asked us to stop sending mid-exchange: honor it now that our
        # responses are complete.
        if self._should_disconnect:
            await self.close()
            self._should_disconnect = False
            self._terminated = True

        results = []

        for response in responses:
            if not response.is_ok:
                # rcode 2 (SERVFAIL) is most commonly a DNSSEC failure here.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs carry a params dict; they only flag h3 support.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (record[-1], port)
                    if inet_type == socket.AF_INET
                    else (record[-1], port, 0, 0)
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        # Mirror every TCP entry as a UDP/h3 candidate when the server
        # advertised h3 support — unless the caller pinned a non-stream kind.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
    async def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """Pump the QUIC connection until an event of *event_type* arrives.

        Sends pending datagrams (unless *receive_first*), then alternates
        receive/feed/flush cycles, collecting events matching
        *event_type_collectable* (or all events if None). Returns as soon as
        *event_type* is seen; when *respect_end_stream_signal* is true, the
        matching event must also mark the end of its stream (when it carries
        that attribute at all).
        """
        assert self._socket is not None

        while True:
            if receive_first is False:
                now = monotonic()
                # Flush everything the state machine already queued.
                while True:
                    datagrams = self._quic.datagrams_to_send(now)

                    if not datagrams:
                        break

                    for datagram in datagrams:
                        data, addr = datagram
                        await self._socket.sendall(data)

            events = []

            while True:
                # Only hit the socket when the state machine has no buffered
                # events left to hand out.
                if not self._quic._events:
                    data_in = await self._socket.recv(1500)

                    # Empty read: peer is gone — return whatever we gathered.
                    if not data_in:
                        break

                    now = monotonic()

                    # recv may yield a single datagram or a list of GRO segments.
                    if not isinstance(data_in, list):
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )
                    else:
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )

                    # Feeding input may produce ACKs/handshake replies; flush them.
                    while True:
                        datagrams = self._quic.datagrams_to_send(now)

                        if not datagrams:
                            break

                        for datagram in datagrams:
                            data, addr = datagram
                            await self._socket.sendall(data)

                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        # Defer the disconnect; the caller finishes its reads first.
                        self._should_disconnect = True
                        continue

                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)

                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        # Events without stream_ended (e.g. handshake) always
                        # satisfy the end-of-stream requirement.
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events

            return events
class AdGuardResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over QUIC preset targeting AdGuard's unfiltered public resolver."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over QUIC preset targeting the NextDNS public resolver."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._ssl import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
TLSResolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"TLSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"Quad9Resolver",
|
||||
"OpenDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,197 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util._async.ssl_ import ssl_wrap_socket
|
||||
from .....util.ssl_ import resolve_cert_reqs
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # DNS over TLS rides TCP, unlike PlainResolver's UDP default. Must be
        # set before super().__init__, which only defaults _socket_type when
        # it is absent.
        self._socket_type = socket.SOCK_STREAM

        super().__init__(server, port or 853, *patterns, **kwargs)

        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Establish the TCP+TLS channel on first use, then delegate the
        actual DNS exchange to ``PlainResolver.getaddrinfo``.

        The first caller wins the connect; concurrent callers block inside the
        parent implementation on ``_connect_finalized``.
        """
        if self._socket is None and self._connect_attempt.is_set() is False:
            assert self.server is not None
            self._connect_attempt.set()
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )
            # Every TLS knob is forwarded verbatim from the constructor kwargs;
            # absent keys fall back to None so ssl_wrap_socket applies defaults.
            self._socket = await ssl_wrap_socket(
                self._socket,
                server_hostname=self.server
                if "server_hostname" not in self._kwargs
                else self._kwargs["server_hostname"],
                keyfile=self._kwargs["key_file"]
                if "key_file" in self._kwargs
                else None,
                certfile=self._kwargs["cert_file"]
                if "cert_file" in self._kwargs
                else None,
                cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
                if "cert_reqs" in self._kwargs
                else None,
                ca_certs=self._kwargs["ca_certs"]
                if "ca_certs" in self._kwargs
                else None,
                ssl_version=self._kwargs["ssl_version"]
                if "ssl_version" in self._kwargs
                else None,
                ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
                ca_cert_dir=self._kwargs["ca_cert_dir"]
                if "ca_cert_dir" in self._kwargs
                else None,
                key_password=self._kwargs["key_password"]
                if "key_password" in self._kwargs
                else None,
                ca_cert_data=self._kwargs["ca_cert_data"]
                if "ca_cert_data" in self._kwargs
                else None,
                certdata=self._kwargs["cert_data"]
                if "cert_data" in self._kwargs
                else None,
                keydata=self._kwargs["key_data"]
                if "key_data" in self._kwargs
                else None,
            )
            self._connect_finalized.set()

        return await super().getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
class GoogleResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over TLS preset targeting Google Public DNS."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over TLS preset targeting Cloudflare's 1.1.1.1."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over TLS preset targeting AdGuard's unfiltered public resolver."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over TLS preset targeting Cisco OpenDNS."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS over TLS preset targeting Quad9."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
PlainResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"PlainResolver",
|
||||
"CloudflareResolver",
|
||||
"GoogleResolver",
|
||||
"Quad9Resolver",
|
||||
"AdGuardResolver",
|
||||
)
|
||||
@@ -0,0 +1,431 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from ....ssa import AsyncSocket
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
packet_fragment,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..protocols import AsyncBaseResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class PlainResolver(AsyncBaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035

    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        super().__init__(server, port, *patterns, **kwargs)

        # Lazily connected on the first getaddrinfo() call.
        self._socket: AsyncSocket | None = None

        # Subclasses (e.g. DNS over TLS) may have chosen SOCK_STREAM already.
        if not hasattr(self, "_socket_type"):
            self._socket_type = socket.SOCK_DGRAM

        # "source_address" accepts "ip" or "ip:port".
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            if ":" in kwargs["source_address"]:
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
                self._source_address: tuple[str, int] | None = (
                    bind_ip,
                    int(bind_port),
                )
            else:
                self._source_address = (kwargs["source_address"], 0)
        else:
            self._source_address = None

        if "timeout" in kwargs and isinstance(
            kwargs["timeout"],
            (
                float,
                int,
            ),
        ):
            self._timeout: float | int | None = kwargs["timeout"]
        else:
            self._timeout = None

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # Responses read by one task but addressed to another are parked in
        # _unconsumed; _pending tracks in-flight queries awaiting answers.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

        # Serializes reads on the shared socket across concurrent lookups.
        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._connect_finalized: asyncio.Event = asyncio.Event()

        self._terminated: bool = False

    async def close(self) -> None:  # type: ignore[override]
        """Close the underlying socket (if any) and mark the resolver dead."""
        if not self._terminated:
            # NOTE(review): self._lock is not assigned in this class —
            # presumably provided by AsyncBaseResolver; confirm upstream.
            with self._lock:
                if self._socket is not None:
                    self._socket.close()
                    await self._socket.wait_for_close()
                self._terminated = True

    def is_available(self) -> bool:
        """Report whether this resolver can still serve lookups."""
        return not self._terminated

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over plain DNS, mimicking ``socket.getaddrinfo``.

        IP literals short-circuit without network traffic. Otherwise A/AAAA
        (and optionally HTTPS) queries are sent and answers matched back to
        this task's queries by transaction id, cooperating with concurrent
        lookups through ``_pending`` / ``_unconsumed``.

        Raises ``socket.gaierror`` on any resolution failure.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals: only a family-compatibility check, no DNS query.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,  # IPPROTO_TCP
                    "",
                    (host, port),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,  # IPPROTO_UDP
                    "",
                    (host, port, 0, 0),
                )
            ]

        validate_length_of(host)

        # First task to arrive connects; others wait for _connect_finalized
        # (set here or by a subclass such as TLSResolver).
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()
            assert self.server is not None
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 53),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )
            self._connect_finalized.set()
        else:
            await self._connect_finalized.wait()

        assert self._socket is not None
        await self._socket.wait_for_readiness()

        remote_preemptive_quic_rr = False

        # A UDP/h3 upgrade hint is pointless when datagrams were requested.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # tbq: query types "to be queried" given the requested address family.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        for q in queries:
            payload = bytes(q)
            self._pending.append(q)

            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)

            await self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            # Only one task reads the shared socket at a time.
            await self._read_semaphore.acquire()
            #: There we want to verify if another thread got a response that belong to this thread.
            if self._unconsumed:
                dns_resp = None

                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break

                if dns_resp:
                    self._pending.remove(query)
                    self._unconsumed.remove(dns_resp)
                    self._read_semaphore.release()
                    continue

            try:
                # recv may yield a single datagram or a list of GRO segments.
                data_in_or_segments = await self._socket.recv(1500)

                if isinstance(data_in_or_segments, list):
                    payloads = data_in_or_segments
                elif data_in_or_segments:
                    payloads = [data_in_or_segments]
                else:
                    payloads = []

                # Stream transports (prefix mandated) may split a message:
                # keep reading until the 2-byte length prefix is satisfied.
                if self._rfc1035_prefix_mandated is True and payloads:
                    payload = b"".join(payloads)
                    while rfc1035_should_read(payload):
                        extra = await self._socket.recv(1500)
                        if isinstance(extra, list):
                            payload += b"".join(extra)
                        else:
                            payload += extra
                    payloads = [payload]
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            # Empty read means the peer dropped us mid-resolution.
            if not payloads:
                self._terminated = True
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                )

            pending_raw_identifiers = [_.raw_id for _ in self._pending]

            for payload in payloads:
                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                if self._rfc1035_prefix_mandated is True:
                    fragments = rfc1035_unpack(payload)
                else:
                    fragments = packet_fragment(payload, *pending_raw_identifiers)

                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)

                    if any(dns_resp.id == _.id for _ in queries):
                        responses.append(dns_resp)

                        query_tbr: DomainNameServerQuery | None = None

                        for query_tbr in self._pending:
                            if query_tbr.id == dns_resp.id:
                                break

                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        # Belongs to a concurrent lookup; park it for them.
                        self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                # rcode 2 (SERVFAIL) is most commonly a DNSSEC failure here.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs carry a params dict; they only flag h3 support.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (record[-1], port)
                    if inet_type == socket.AF_INET
                    else (record[-1], port, 0, 0)
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        # Mirror every TCP entry as a UDP/h3 candidate when the server
        # advertised h3 support — unless the caller pinned a non-stream kind.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS preset targeting Cloudflare's 1.1.1.1."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS preset targeting Google Public DNS (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS preset targeting Quad9 (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-DNS preset targeting AdGuard's unfiltered public resolver."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The vendor endpoint is fixed; an explicit port override is honored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("94.140.14.140", port, *patterns, **kwargs)
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from base64 import b64encode
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from ....util import parse_url
|
||||
from ..factories import ResolverDescription
|
||||
from ..protocols import ProtocolResolver
|
||||
from .protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class AsyncResolverFactory(metaclass=ABCMeta):
    """Discover and build concrete :class:`AsyncBaseResolver` implementations.

    Implementations live in submodules named after the protocol value under
    ``<package>.contrib.resolver._async`` (optionally suffixed with a
    ``._<implementation>`` module).
    """

    @staticmethod
    def has(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
    ) -> bool:
        """Return True when at least one resolver matches the given criteria.

        Fix: also filter candidates on ``protocol``, exactly like :meth:`new`
        does. Previously a module exposing resolvers of a *different* protocol
        could make ``has()`` return True while ``new()`` raised
        ``NotImplementedError`` for the same arguments.
        """
        package_name: str = __name__.split(".")[0]
        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError:
            # No module for that protocol/implementation pair exists.
            return False

        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )

        return bool(implementations)

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> AsyncBaseResolver:
        """Instantiate the first resolver matching *protocol*/*specifier*.

        Raises ``NotImplementedError`` when the target module cannot be
        imported or no compatible class is found in it.
        """
        package_name: str = __name__.split(".")[0]

        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver._async"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e

        implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, AsyncBaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            )
            and hasattr(e, "protocol")
            and e.protocol == protocol,
        )

        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        # Arbitrarily take the last discovered candidate, as before.
        implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]

        return implementation_target(**kwargs)
class AsyncResolverDescription(ResolverDescription):
    """Describe how an :class:`AsyncBaseResolver` must be instantiated."""

    def new(self) -> AsyncBaseResolver:
        """Build the resolver instance described by this object."""
        kwargs = {**self.kwargs}

        # Only forward the optional fields that were actually provided.
        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns

        return AsyncResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> AsyncResolverDescription:
        """Build a description from a DNS URL (e.g. ``doh+cloudflare://...``).

        The scheme selects the protocol, optionally suffixed with a vendor
        specifier after ``+``. Credentials in the authority become an
        ``Authorization`` header (Basic for ``user:pass``, Bearer otherwise),
        and query parameters are turned into constructor keyword arguments
        with naive bool/int/float coercion.

        :raises ValueError: when the URL has no scheme or carries more than
            one ``implementation`` parameter.
        """
        parsed_url = parse_url(url)

        schema = parsed_url.scheme

        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")

        specifier = None
        implementation = None

        if "+" in schema:
            # e.g. "doh+cloudflare" -> protocol "doh", specifier "cloudflare".
            schema, specifier = tuple(schema.lower().split("+", 1))

        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}

        if parsed_url.path:
            kwargs["path"] = parsed_url.path

        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                # user:password pair -> HTTP Basic. Split at the FIRST colon
                # only, so that passwords containing ':' are preserved
                # (RFC 7617 forbids ':' in the user-id, not in the password).
                username, password = parsed_url.auth.split(":", 1)

                username = username.strip("'\"")
                password = password.strip("'\"")

                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                # A bare token -> Bearer authentication.
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"

        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)

            for parameter in parameters:
                if not parameters[parameter]:
                    continue

                parameter_insensible = parameter.lower()

                # Case 1: the parameter appears several times in the query.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")

                    values = []

                    # Each occurrence may itself hold a comma-separated list.
                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue

                    kwargs[parameter_insensible] = values
                    continue

                value: str = parameters[parameter][0].lower().strip(" ")

                if parameter == "implementation":
                    implementation = value
                    continue

                # Case 2: a single occurrence holding a comma-separated list.
                if "," in value:
                    list_of_values = value.split(",")

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            # Fold the previous scalar into the new list and
                            # store the merged result. (Previously the merged
                            # list was built but never assigned back, silently
                            # keeping only the old scalar value.)
                            list_of_values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = list_of_values
                        continue

                    kwargs[parameter_insensible] = list_of_values
                    continue

                # Case 3: plain scalar. Coerce bool/int/float when unambiguous.
                value_converted: bool | int | float | None = None

                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)

                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )

        host_patterns: list[str] = []

        # "hosts" is reserved: it carries host patterns, not a constructor kwarg.
        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]

        return AsyncResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._dict import InMemoryResolver
|
||||
|
||||
__all__ = ("InMemoryResolver",)
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util.url import _IPV6_ADDRZ_RE
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class InMemoryResolver(AsyncBaseResolver):
    """A manual, dictionary-backed resolver that performs no network traffic.

    Records are registered programmatically via :meth:`register` or through
    host patterns of the form ``hostname:ipaddr`` passed at construction.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver contacts no server; discard those parameters if given.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)

        # Upper bound on the number of hostnames kept in memory.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        # hostname -> list of (address family, address) records.
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}

        # Host patterns double as the initial record set ("hostname:ipaddr").
        if self._host_patterns:
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            # Patterns were consumed as records; don't keep them as filters.
            self._host_patterns = tuple([])

    def recycle(self) -> AsyncBaseResolver:
        # Stateless w.r.t. connections: the same instance can be reused.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Only explicitly registered hostnames can be resolved.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True when *hostname* has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add a record for *hostname*; already-known addresses are skipped."""
        if hostname not in self._hosts:
            self._hosts[hostname] = []
        else:
            for e in self._hosts[hostname]:
                t, addr = e
                # NOTE(review): substring containment, not equality — a stored
                # "1.2.3.4" would also match a new "11.2.3.45". Confirm whether
                # ``addr == ipaddr`` was intended here.
                if addr in ipaddr:
                    return

        if _IPV6_ADDRZ_RE.match(ipaddr):
            # Bracketed IPv6 literal: strip the surrounding "[" and "]".
            self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
        elif is_ipv6(ipaddr):
            self._hosts[hostname].append((socket.AF_INET6, ipaddr))
        else:
            self._hosts[hostname].append((socket.AF_INET, ipaddr))

        # Rudimentary bound: evict the first hostname in insertion order.
        if len(self._hosts) > self._maxsize:
            k = None
            for k in self._hosts.keys():
                break
            if k:
                self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every record registered for *hostname*."""
        if hostname in self._hosts:
            del self._hosts[hostname]

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* from the in-memory table, mimicking the stdlib API.

        :raises socket.gaierror: when *host* has no matching record.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals bypass the table entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        if host not in self._hosts:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        for entry in self._hosts[host]:
            addr_type, addr_target = entry

            # Honour the requested address family (AF_UNSPEC accepts all).
            if family != socket.AF_UNSPEC:
                if family != addr_type:
                    continue

            results.append(
                (
                    addr_type,
                    type,
                    6 if type == socket.SOCK_STREAM else 17,
                    "",
                    (addr_target, port)
                    if addr_type == socket.AF_INET
                    else (addr_target, port, 0, 0),
                )
            )

        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        return results
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class NullResolver(AsyncBaseResolver):
    """A resolver that only accepts literal IP addresses.

    Anything that actually requires a DNS lookup raises, which makes this
    backend a way to forbid name resolution entirely.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # No server is ever contacted; silently drop those parameters.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> AsyncBaseResolver:
        # Stateless: the very same instance can serve forever.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Return a stdlib-shaped record for an IP literal, raise otherwise."""
        # Normalize the arguments the way CPython's socket module does.
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [(socket.AF_INET, type, 6, "", (host, port))]

        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]

        # Not an IP literal: by design this resolver refuses to look it up.
        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
||||
|
||||
|
||||
__all__ = ("NullResolver",)
|
||||
@@ -0,0 +1,375 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from ...._constant import UDP_LINUX_GRO
|
||||
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
|
||||
from ....exceptions import LocationParseError
|
||||
from ....util.connection import _set_socket_options, allowed_gai_family
|
||||
from ....util.timeout import _DEFAULT_TIMEOUT
|
||||
from ...ssa import AsyncSocket
|
||||
from ...ssa._timeout import timeout as timeout_
|
||||
from ..protocols import BaseResolver
|
||||
|
||||
|
||||
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
    """Async counterpart of BaseResolver: resolve hostnames and open sockets."""

    def recycle(self) -> AsyncBaseResolver:
        # Delegate to the sync base implementation; only the declared return
        # type is narrowed here.
        return super().recycle()  # type: ignore[return-value]

    @abstractmethod
    async def close(self) -> None:  # type: ignore[override]
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    async def create_connection(  # type: ignore[override]
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> AsyncSocket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.
        """

        host, port = address
        # Strip the brackets from an IPv6 literal ("[::1]" -> "::1").
        if host.startswith("["):
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        # Binding to an explicit source address fixes the family accordingly.
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        dt_pre_resolve = datetime.now(tz=timezone.utc)
        if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
            # we can hang here in case of bad networking conditions
            # the DNS may never answer or the packets can be lost.
            # this isn't possible in sync mode. unfortunately.
            # found by user at https://github.com/jawah/niquests/issues/183
            # todo: find a way to limit getaddrinfo delays in sync mode.
            try:
                async with timeout_(timeout):
                    records = await self.getaddrinfo(
                        host,
                        port,
                        default_socket_family,
                        socket_kind,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
            except TimeoutError:
                raise socket.gaierror(
                    f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
                )
        else:
            records = await self.getaddrinfo(
                host,
                port,
                default_socket_family,
                socket_kind,
                quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
            )
        # Time spent purely on name resolution (surfaced via timing_hook).
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)
        # Try every resolved record in order; the first successful connect wins.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = AsyncSocket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (OSError, AttributeError):  # Defensive: very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)

                try:
                    await sock.connect(sa)
                except asyncio.CancelledError:
                    # Don't leak the half-open socket on task cancellation.
                    sock.close()
                    raise

                # Break explicitly a reference cycle
                err = None

                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )

                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )

                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. port out of range) won't improve with
                # another record: stop retrying immediately.
                if isinstance(_, OverflowError):
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
||||
class AsyncManyResolver(AsyncBaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: AsyncBaseResolver) -> None:
        super().__init__(None, None)

        # Total child count (used for the error messages in getaddrinfo).
        self._size = len(resolvers)

        # Children able to resolve anything...
        self._unconstrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        # ...and children restricted to specific hostnames.
        self._constrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]

        # In-flight inquiry counter; drives the rotating start offset below.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> AsyncBaseResolver:
        """Return a fresh AsyncManyResolver with every child recycled."""
        resolvers = []

        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())

        return AsyncManyResolver(*resolvers)

    async def close(self) -> None:  # type: ignore[override]
        for resolver in self._unconstrained + self._constrained:
            await resolver.close()

        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[AsyncBaseResolver, None, None]:
        """Yield children starting at a rotating offset, recycling unavailable ones."""
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            # Rotate the starting point so concurrent callers spread the load.
            start_idx = (self._concurrent - 1) % resolver_count

            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            # Wrap around to also cover the children before the start offset.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Ask constrained children first, then fall back to unconstrained ones.

        :raises socket.gaierror: when no child produced a usable record.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        # NOTE(review): appended to but never read afterwards — looks vestigial.
        tested_resolvers = []

        any_constrained_tried: bool = False

        # Pass 1: children explicitly constrained to (supporting) this host.
        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True

                try:
                    results = await resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )

                    if results:
                        return results
                except socket.gaierror as exc:
                    # DNSSEC/DNSKEY failures are authoritative: do not retry.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)

        # A constrained child claimed this host but failed: don't leak the
        # query to the unconstrained resolvers.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        # Pass 2: unconstrained children, in rotation order.
        for resolver in self.__resolvers():
            try:
                results = await resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )

                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import SystemResolver
|
||||
|
||||
__all__ = ("SystemResolver",)
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class SystemResolver(AsyncBaseResolver):
    """Delegate hostname resolution to the operating system via the event loop."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The OS resolver has no configurable endpoint; drop those arguments.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        # The local machine is always resolvable by the system.
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> AsyncBaseResolver:
        # Stateless: safe to reuse the very same instance.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op!

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        # quic_upgrade_via_dns_rr is accepted for signature parity with the
        # other resolvers and is simply unused here.
        loop = asyncio.get_running_loop()
        return await loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._urllib3 import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
HTTPSResolver,
|
||||
NextDNSResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"HTTPSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"OpenDNSResolver",
|
||||
"Quad9Resolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,641 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from ...._collections import HTTPHeaderDict
|
||||
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
|
||||
from ....connectionpool import HTTPSConnectionPool
|
||||
from ....response import HTTPResponse
|
||||
from ....util.url import parse_url
|
||||
from ..protocols import (
|
||||
BaseResolver,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
|
||||
|
||||
|
||||
class HTTPSResolver(BaseResolver):
|
||||
"""
|
||||
Advanced DNS over HTTPS resolver.
|
||||
No common ground emerged from IETF w/ JSON. Following Google’s DNS over HTTPS schematics that is
|
||||
also implemented at Cloudflare.
|
||||
|
||||
Support RFC 8484 without JSON. Disabled by default.
|
||||
"""
|
||||
|
||||
implementation = "urllib3"
|
||||
protocol = ProtocolResolver.DOH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
super().__init__(server, port or 443, *patterns, **kwargs)
|
||||
|
||||
self._path: str = "/resolve"
|
||||
|
||||
if "path" in kwargs:
|
||||
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
|
||||
self._path = kwargs["path"]
|
||||
kwargs.pop("path")
|
||||
|
||||
self._rfc8484: bool = False
|
||||
|
||||
if "rfc8484" in kwargs:
|
||||
if kwargs["rfc8484"]:
|
||||
self._rfc8484 = True
|
||||
kwargs.pop("rfc8484")
|
||||
|
||||
assert self._server is not None
|
||||
|
||||
if "source_address" in kwargs:
|
||||
if isinstance(kwargs["source_address"], str):
|
||||
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
|
||||
|
||||
if bind_ip and bind_port.isdigit():
|
||||
kwargs["source_address"] = (
|
||||
bind_ip,
|
||||
int(bind_port),
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
|
||||
if "proxy" in kwargs:
|
||||
kwargs["_proxy"] = parse_url(kwargs["proxy"])
|
||||
kwargs.pop("proxy")
|
||||
|
||||
if "maxsize" not in kwargs:
|
||||
kwargs["maxsize"] = 10
|
||||
|
||||
if "proxy_headers" in kwargs and "_proxy" in kwargs:
|
||||
proxy_headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["proxy_headers"], list):
|
||||
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
|
||||
|
||||
for item in kwargs["proxy_headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
proxy_headers.add(k, v)
|
||||
|
||||
kwargs["_proxy_headers"] = proxy_headers
|
||||
|
||||
if "headers" in kwargs:
|
||||
headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["headers"], list):
|
||||
kwargs["headers"] = [kwargs["headers"]]
|
||||
|
||||
for item in kwargs["headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
headers.add(k, v)
|
||||
|
||||
kwargs["headers"] = headers
|
||||
|
||||
if "disabled_svn" in kwargs:
|
||||
if not isinstance(kwargs["disabled_svn"], list):
|
||||
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
|
||||
|
||||
disabled_svn = set()
|
||||
|
||||
for svn in kwargs["disabled_svn"]:
|
||||
svn = svn.lower()
|
||||
|
||||
if svn == "h11":
|
||||
disabled_svn.add(HttpVersion.h11)
|
||||
elif svn == "h2":
|
||||
disabled_svn.add(HttpVersion.h2)
|
||||
elif svn == "h3":
|
||||
disabled_svn.add(HttpVersion.h3)
|
||||
|
||||
kwargs["disabled_svn"] = disabled_svn
|
||||
|
||||
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
|
||||
self._connection_callback: (
|
||||
typing.Callable[[ConnectionInfo], None] | None
|
||||
) = kwargs["on_post_connection"]
|
||||
kwargs.pop("on_post_connection")
|
||||
else:
|
||||
self._connection_callback = None
|
||||
|
||||
self._pool = HTTPSConnectionPool(self._server, self._port, **kwargs)
|
||||
|
||||
    def close(self) -> None:
        """Shut down the underlying HTTPS connection pool used for DoH queries."""
        self._pool.close()
|
||||
|
||||
    def is_available(self) -> bool:
        """Return True while the HTTPS pool has not been torn down.

        ``self._pool.pool`` is set to None once the pool is closed, so this is
        a cheap liveness probe with no network activity.
        """
        return self._pool.pool is not None
|
||||
|
||||
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* via DNS-over-HTTPS, mirroring ``socket.getaddrinfo``.

        Issues A and/or AAAA queries (and optionally an HTTPS/type-65 RR query
        when *quic_upgrade_via_dns_rr* is set) against the configured DoH
        endpoint. Depending on ``self._rfc8484`` the wire format is either
        JSON (``application/dns-json``) or RFC 8484 binary DNS messages
        (``application/dns-message``). Literal IPv4/IPv6 hosts short-circuit
        without any network round-trip.

        Raises ``socket.gaierror`` on any resolution failure, matching the
        stdlib contract.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses need no resolution: fabricate the stdlib-shaped
        # 5-tuple directly (proto 6 = TCP for IPv4 path, 17 = UDP for IPv6
        # path, as in the original).
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        promises: list[HTTPResponse | ResponsePromise] = []
        remote_preemptive_quic_rr = False

        # A datagram socket already implies QUIC; probing for an upgrade RR
        # would be redundant.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Fire the A query (JSON type "1" / binary SupportedQueryType.A).
        # Requests are multiplexed so all queries go out before any response
        # is read.
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            # RFC 8484 GET: base64url payload with padding
                            # stripped.
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # Fire the AAAA query (JSON type "28").
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # Fire the HTTPS/SVCB query (JSON type "65") to learn whether the
        # remote advertises h3 (QUIC) support.
        if quic_upgrade_via_dns_rr:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        responses: list[HTTPResponse] = []

        # Collect the multiplexed responses; a promise may already be a
        # materialized HTTPResponse.
        for promise in promises:
            if isinstance(promise, HTTPResponse):
                responses.append(promise)
                continue
            responses.append(self._pool.get_response(promise=promise))  # type: ignore[arg-type]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )

            if not self._rfc8484:
                # JSON (application/dns-json) response path.
                payload = response.json()

                assert "Status" in payload and isinstance(payload["Status"], int)

                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )

                    # Some servers return "Comment" as a list of strings.
                    if isinstance(msg, list):
                        msg = ", ".join(msg)

                    raise socket.gaierror(msg)

                assert "Question" in payload and isinstance(payload["Question"], list)

                # No "Answer" section means no records for this RR type.
                if "Answer" not in payload:
                    continue

                assert isinstance(payload["Answer"], list)

                for answer in payload["Answer"]:
                    if answer["type"] not in [1, 28, 65]:
                        continue

                    assert "data" in answer
                    assert isinstance(answer["data"], str)

                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]

                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            rr = "".join(rr[2:].split(" ")[2:])

                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""

                            https_record = parse_https_rdata(raw_record)

                            if "h3" not in https_record["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True
                        else:
                            # Presentation format: split "k=v" pairs.
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )

                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True

                            # ipv4hint/ipv6hint let us emit UDP candidates
                            # without waiting for the A/AAAA answers.
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )

                        continue

                    # A (type 1) or AAAA (type 28) record.
                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )

                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # RFC 8484 binary (application/dns-message) response path.
                dns_resp = DomainNameServerReturn(response.data)

                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue

                    assert not isinstance(record[-1], dict)

                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )
                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )

        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        # If the remote advertised h3, duplicate the TCP candidates as UDP
        # ones — unless the caller already asked for a specific non-stream
        # socket kind.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class GoogleResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Binary RFC 8484 messages live on /dns-query; the JSON API keeps the
        # parent's default path.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting Cloudflare (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Cloudflare serves DoH on the well-known /dns-query path.
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # AdGuard speaks RFC 8484 binary messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # OpenDNS speaks RFC 8484 binary messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Quad9 speaks RFC 8484 binary messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS preset targeting NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# qh3 is an optional dependency: the DNS-over-QUIC resolvers are exposed only
# when the backing module imports cleanly (._qh3 also raises ImportError under
# a FIPS-built ssl module). Otherwise the names are published as None so that
# callers can feature-detect instead of crashing at import time.
try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    QUICResolver = None  # type: ignore
    AdGuardResolver = None  # type: ignore
    NextDNSResolver = None  # type: ignore


__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)
|
||||
@@ -0,0 +1,541 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from ssl import SSLError
|
||||
from time import time as monotonic
|
||||
|
||||
from qh3.quic.configuration import QuicConfiguration
|
||||
from qh3.quic.connection import QuicConnection
|
||||
from qh3.quic.events import (
|
||||
ConnectionTerminated,
|
||||
HandshakeCompleted,
|
||||
QuicEvent,
|
||||
StopSendingReceived,
|
||||
StreamDataReceived,
|
||||
StreamReset,
|
||||
)
|
||||
|
||||
from ....util.ssl_ import IS_FIPS, resolve_cert_reqs
|
||||
from ...ssa._gro import _sock_has_gro, _sock_has_gso, sync_recv_gro, sync_sendmsg_gso
|
||||
from ..dou import PlainResolver
|
||||
from ..protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
|
||||
if IS_FIPS:
|
||||
raise ImportError(
|
||||
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
|
||||
)
|
||||
|
||||
|
||||
class QUICResolver(PlainResolver):
|
||||
protocol = ProtocolResolver.DOQ
|
||||
implementation = "qh3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
):
|
||||
super().__init__(server, port or 853, *patterns, **kwargs)
|
||||
|
||||
# qh3 load_default_certs seems off. need to investigate.
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
kwargs["ca_cert_data"] = []
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
|
||||
try:
|
||||
ctx.load_default_certs()
|
||||
|
||||
for der in ctx.get_ca_certs(binary_form=True):
|
||||
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
|
||||
|
||||
if kwargs["ca_cert_data"]:
|
||||
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
|
||||
else:
|
||||
del kwargs["ca_cert_data"]
|
||||
except (AttributeError, ValueError, OSError):
|
||||
del kwargs["ca_cert_data"]
|
||||
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
if (
|
||||
"cert_reqs" not in kwargs
|
||||
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
|
||||
):
|
||||
raise ssl.SSLError(
|
||||
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
|
||||
"Add ?cert_reqs=0 to disable certificate checks."
|
||||
)
|
||||
|
||||
configuration = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["doq"],
|
||||
server_name=self._server
|
||||
if "server_hostname" not in kwargs
|
||||
else kwargs["server_hostname"],
|
||||
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
|
||||
if "cert_reqs" in kwargs
|
||||
else ssl.CERT_REQUIRED,
|
||||
cadata=kwargs["ca_cert_data"].encode()
|
||||
if "ca_cert_data" in kwargs
|
||||
else None,
|
||||
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
|
||||
idle_timeout=300.0,
|
||||
)
|
||||
|
||||
if "cert_file" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_file"],
|
||||
kwargs["key_file"] if "key_file" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
elif "cert_data" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_data"],
|
||||
kwargs["key_data"] if "key_data" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
|
||||
self._quic = QuicConnection(configuration=configuration)
|
||||
|
||||
self._dgram_gro_enabled: bool = _sock_has_gro(self._socket)
|
||||
self._dgram_gso_enabled: bool = _sock_has_gso(self._socket)
|
||||
|
||||
self._quic.connect((self._server, self._port), monotonic())
|
||||
self.__exchange_until(HandshakeCompleted, receive_first=False)
|
||||
|
||||
self._terminated: bool = False
|
||||
self._should_disconnect: bool = False
|
||||
|
||||
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
|
||||
self._rfc1035_prefix_mandated = True
|
||||
|
||||
self._unconsumed: deque[DomainNameServerReturn] = deque()
|
||||
self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
    def close(self) -> None:
        """Terminate the QUIC connection and release the UDP socket.

        Asks qh3 to close, then flushes every pending outgoing datagram (the
        close frames included) before closing the socket. Idempotent: a no-op
        once the connection is marked terminated.
        """
        if not self._terminated:
            with self._lock:
                self._quic.close()

                # Drain qh3's outgoing datagram queue so the peer actually
                # sees the connection-close before we drop the socket.
                while True:
                    datagrams = self._quic.datagrams_to_send(monotonic())

                    if not datagrams:
                        break

                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        # Batch into a single GSO sendmsg when the socket
                        # supports it.
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])

                self._socket.close()
                self._terminated = True
|
||||
|
||||
    def is_available(self) -> bool:
        """Return True while the QUIC connection has not been terminated.

        Runs qh3's timer handler first so that an idle-timeout or remote
        close that became due is observed before answering.
        """
        self._quic.handle_timer(monotonic())
        # NOTE(review): peeks at qh3's private ``_close_event`` attribute —
        # this relies on a QuicConnection implementation detail; confirm it
        # still exists when upgrading qh3.
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated
|
||||
|
||||
def getaddrinfo(
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' using the QUICResolver"
|
||||
)
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
tbq = []
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
tbq.append(SupportedQueryType.A)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
tbq.append(SupportedQueryType.AAAA)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
tbq.append(SupportedQueryType.HTTPS)
|
||||
|
||||
queries = DomainNameServerQuery.bulk(host, *tbq)
|
||||
open_streams = []
|
||||
|
||||
with self._lock:
|
||||
for q in queries:
|
||||
payload = bytes(q)
|
||||
|
||||
self._pending.append(q)
|
||||
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
payload = rfc1035_pack(payload)
|
||||
|
||||
stream_id = self._quic.get_next_available_stream_id()
|
||||
self._quic.send_stream_data(stream_id, payload, True)
|
||||
|
||||
open_streams.append(stream_id)
|
||||
|
||||
datagrams = self._quic.datagrams_to_send(monotonic())
|
||||
if self._dgram_gso_enabled and len(datagrams) > 1:
|
||||
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
|
||||
else:
|
||||
for dg in datagrams:
|
||||
self._socket.sendall(dg[0])
|
||||
|
||||
responses: list[DomainNameServerReturn] = []
|
||||
|
||||
while len(responses) < len(tbq):
|
||||
with self._lock:
|
||||
if self._unconsumed:
|
||||
dns_resp = None
|
||||
for query in queries:
|
||||
for unconsumed in self._unconsumed:
|
||||
if unconsumed.id == query.id:
|
||||
dns_resp = unconsumed
|
||||
responses.append(dns_resp)
|
||||
break
|
||||
if dns_resp:
|
||||
break
|
||||
if dns_resp:
|
||||
self._unconsumed.remove(dns_resp)
|
||||
self._pending.remove(query)
|
||||
continue
|
||||
|
||||
try:
|
||||
events: list[StreamDataReceived] = self.__exchange_until( # type: ignore[assignment]
|
||||
StreamDataReceived,
|
||||
receive_first=True,
|
||||
event_type_collectable=(StreamDataReceived,),
|
||||
respect_end_stream_signal=False,
|
||||
)
|
||||
|
||||
payload = b"".join([e.data for e in events])
|
||||
|
||||
while rfc1035_should_read(payload):
|
||||
events.extend(
|
||||
self.__exchange_until( # type: ignore[arg-type]
|
||||
StreamDataReceived,
|
||||
receive_first=True,
|
||||
event_type_collectable=(StreamDataReceived,),
|
||||
respect_end_stream_signal=False,
|
||||
)
|
||||
)
|
||||
payload = b"".join([e.data for e in events])
|
||||
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
) from e
|
||||
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
|
||||
fragments = rfc1035_unpack(payload)
|
||||
|
||||
for fragment in fragments:
|
||||
dns_resp = DomainNameServerReturn(fragment)
|
||||
|
||||
if any(dns_resp.id == _.id for _ in queries):
|
||||
responses.append(dns_resp)
|
||||
|
||||
query_tbr: DomainNameServerQuery | None = None
|
||||
|
||||
for query_tbr in self._pending:
|
||||
if query_tbr.id == dns_resp.id:
|
||||
break
|
||||
if query_tbr:
|
||||
self._pending.remove(query_tbr)
|
||||
else:
|
||||
self._unconsumed.append(dns_resp)
|
||||
|
||||
if self._should_disconnect:
|
||||
with self._lock:
|
||||
self.close()
|
||||
self._should_disconnect = False
|
||||
self._terminated = True
|
||||
|
||||
results = []
|
||||
|
||||
for response in responses:
|
||||
if not response.is_ok:
|
||||
if response.rcode == 2:
|
||||
raise socket.gaierror(
|
||||
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
|
||||
)
|
||||
|
||||
for record in response.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
def __exchange_until(
|
||||
self,
|
||||
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
|
||||
*,
|
||||
receive_first: bool = False,
|
||||
event_type_collectable: type[QuicEvent]
|
||||
| tuple[type[QuicEvent], ...]
|
||||
| None = None,
|
||||
respect_end_stream_signal: bool = True,
|
||||
) -> list[QuicEvent]:
|
||||
while True:
|
||||
if receive_first is False:
|
||||
now = monotonic()
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
if self._dgram_gso_enabled and len(datagrams) > 1:
|
||||
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
|
||||
else:
|
||||
for datagram in datagrams:
|
||||
self._socket.sendall(datagram[0])
|
||||
|
||||
events = []
|
||||
|
||||
while True:
|
||||
if not self._quic._events:
|
||||
if self._dgram_gro_enabled:
|
||||
data_in = sync_recv_gro(self._socket, 65535)
|
||||
else:
|
||||
data_in = self._socket.recv(1500)
|
||||
|
||||
if not data_in:
|
||||
break
|
||||
|
||||
now = monotonic()
|
||||
|
||||
if isinstance(data_in, list):
|
||||
for gro_segment in data_in:
|
||||
self._quic.receive_datagram(
|
||||
gro_segment, (self._server, self._port), now
|
||||
)
|
||||
else:
|
||||
self._quic.receive_datagram(
|
||||
data_in, (self._server, self._port), now
|
||||
)
|
||||
|
||||
while True:
|
||||
now = monotonic()
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
if self._dgram_gso_enabled and len(datagrams) > 1:
|
||||
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
|
||||
else:
|
||||
for datagram in datagrams:
|
||||
self._socket.sendall(datagram[0])
|
||||
|
||||
for ev in iter(self._quic.next_event, None):
|
||||
if isinstance(ev, ConnectionTerminated):
|
||||
if ev.error_code == 298:
|
||||
raise SSLError(
|
||||
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
|
||||
)
|
||||
elif isinstance(ev, StreamReset):
|
||||
self._terminated = True
|
||||
raise socket.gaierror(
|
||||
"DNS over QUIC server submitted a StreamReset. A request was rejected."
|
||||
)
|
||||
elif isinstance(ev, StopSendingReceived):
|
||||
self._should_disconnect = True
|
||||
continue
|
||||
|
||||
if event_type_collectable:
|
||||
if isinstance(ev, event_type_collectable):
|
||||
events.append(ev)
|
||||
else:
|
||||
events.append(ev)
|
||||
|
||||
if isinstance(ev, event_type):
|
||||
if not respect_end_stream_signal:
|
||||
return events
|
||||
if hasattr(ev, "stream_ended") and ev.stream_ended:
|
||||
return events
|
||||
elif hasattr(ev, "stream_ended") is False:
|
||||
return events
|
||||
|
||||
return events
|
||||
|
||||
|
||||
class AdGuardResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC preset targeting AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC preset targeting NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The endpoint host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._ssl import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
TLSResolver,
|
||||
)
|
||||
|
||||
# Public surface of the DNS-over-TLS resolver submodule: the generic
# TLSResolver plus the vendor presets re-exported from ._ssl.
__all__ = (
    "TLSResolver",
    "GoogleResolver",
    "CloudflareResolver",
    "AdGuardResolver",
    "Quad9Resolver",
    "OpenDNSResolver",
)
|
||||
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ....util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
|
||||
from ..dou import PlainResolver
|
||||
from ..protocols import ProtocolResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # Optional numeric connect timeout; anything else is ignored.
        if "timeout" in kwargs and isinstance(kwargs["timeout"], (int, float)):
            timeout = kwargs["timeout"]
        else:
            timeout = None

        # Optional "ip:port" to bind the outgoing socket to.
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            bind_ip, bind_port = kwargs["source_address"].split(":", 1)
        else:
            bind_ip, bind_port = "0.0.0.0", "0"

        # Establish the plain TCP connection first (standard DoT port 853
        # when none is given). The system resolver bootstraps the address of
        # the DoT server itself.
        self._socket = SystemResolver().create_connection(
            (server, port or 853),
            timeout=timeout,
            source_address=(bind_ip, int(bind_port))
            if bind_ip != "0.0.0.0" or bind_port != "0"
            else None,
            socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
            socket_kind=socket.SOCK_STREAM,
        )

        # NOTE(review): super().__init__ runs after the TCP socket above is
        # assigned and before the TLS wrap below — presumably PlainResolver
        # leaves an existing self._socket untouched; verify against its
        # implementation.
        super().__init__(server, port, *patterns, **kwargs)

        # Upgrade the TCP socket to TLS; every parameter is forwarded from
        # kwargs when present, otherwise left to ssl_wrap_socket defaults.
        self._socket = ssl_wrap_socket(
            self._socket,
            server_hostname=server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            keyfile=kwargs["key_file"] if "key_file" in kwargs else None,
            certfile=kwargs["cert_file"] if "cert_file" in kwargs else None,
            cert_reqs=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else None,
            ca_certs=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            ssl_version=kwargs["ssl_version"] if "ssl_version" in kwargs else None,
            ciphers=kwargs["ciphers"] if "ciphers" in kwargs else None,
            ca_cert_dir=kwargs["ca_cert_dir"] if "ca_cert_dir" in kwargs else None,
            key_password=kwargs["key_password"] if "key_password" in kwargs else None,
            ca_cert_data=kwargs["ca_cert_data"] if "ca_cert_data" in kwargs else None,
            certdata=kwargs["cert_data"] if "cert_data" in kwargs else None,
            keydata=kwargs["key_data"] if "key_data" in kwargs else None,
        )

        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True
|
||||
|
||||
|
||||
class GoogleResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS resolver preconfigured for Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS resolver preconfigured for Cloudflare DNS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS resolver preconfigured for AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS resolver preconfigured for Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS resolver preconfigured for Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
PlainResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"PlainResolver",
|
||||
"CloudflareResolver",
|
||||
"GoogleResolver",
|
||||
"Quad9Resolver",
|
||||
"AdGuardResolver",
|
||||
)
|
||||
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from ...ssa._gro import _sock_has_gro, sync_recv_gro
|
||||
from ..protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
BaseResolver,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..system import SystemResolver
|
||||
from ..utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
packet_fragment,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
|
||||
|
||||
class PlainResolver(BaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035

    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        super().__init__(server, port, *patterns, **kwargs)

        # A subclass (e.g. a TLS transport) may have created the socket already;
        # only open a plain UDP socket ourselves when none exists.
        if not hasattr(self, "_socket"):
            # Accept a numeric "timeout" kwarg; anything else is ignored.
            if "timeout" in kwargs and isinstance(
                kwargs["timeout"],
                (
                    float,
                    int,
                ),
            ):
                timeout = kwargs["timeout"]
            else:
                timeout = None

            # "source_address" is expressed as "ip:port"; defaults mean "unbound".
            if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
            else:
                bind_ip, bind_port = "0.0.0.0", "0"

            self._socket = SystemResolver().create_connection(
                (server, port or 53),
                timeout=timeout,
                source_address=(bind_ip, int(bind_port))
                if bind_ip != "0.0.0.0" or bind_port != "0"
                else None,
                socket_options=None,
                socket_kind=socket.SOCK_DGRAM,
            )

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # Whether the kernel supports UDP generic receive offload on this socket.
        self._gro_enabled: bool = _sock_has_gro(self._socket)

        # Responses read off the wire that belong to another thread's queries.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        # Queries sent but not yet answered.
        self._pending: deque[DomainNameServerQuery] = deque()

        self._terminated: bool = False

    def close(self) -> None:
        """Shut down and close the underlying socket; the resolver becomes unusable."""
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    self._socket.shutdown(0)
                    self._socket.close()
                self._terminated = True

    def is_available(self) -> bool:
        """True until :meth:`close` was called or the peer disconnected mid-resolution."""
        return not self._terminated

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` via RFC 1035 queries, mimicking ``socket.getaddrinfo``.

        Literal IPv4/IPv6 hosts short-circuit without any network traffic.
        When ``quic_upgrade_via_dns_rr`` is set (and a TCP lookup is requested),
        an HTTPS record is additionally queried to detect h3/QUIC support.
        Raises ``socket.gaierror`` on any resolution failure.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass the DNS entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        remote_preemptive_quic_rr = False

        # A datagram lookup is already UDP; no point probing for a QUIC upgrade.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Record types to be queried, depending on the requested family.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        with self._lock:
            for q in queries:
                payload = bytes(q)
                self._pending.append(q)

                # DoT and friends prefix each dns-message with a 2-byte length.
                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)

                self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            with self._lock:
                #: There we want to verify if another thread got a response that belong to this thread.
                if self._unconsumed:
                    dns_resp = None

                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break

                    if dns_resp:
                        self._pending.remove(query)
                        self._unconsumed.remove(dns_resp)
                        continue

                try:
                    if self._gro_enabled:
                        # GRO may deliver several coalesced datagrams at once.
                        data_in_or_segments = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in_or_segments = self._socket.recv(1500)

                    if isinstance(data_in_or_segments, list):
                        payloads = data_in_or_segments
                    elif data_in_or_segments:
                        payloads = [data_in_or_segments]
                    else:
                        payloads = []

                    # Stream transports: keep reading until the advertised
                    # length-prefix is fully satisfied.
                    if self._rfc1035_prefix_mandated is True and payloads:
                        payload = b"".join(payloads)
                        while rfc1035_should_read(payload):
                            extra = self._socket.recv(1500)
                            if isinstance(extra, list):
                                payload += b"".join(extra)
                            else:
                                payload += extra
                        payloads = [payload]
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e

                if not payloads:
                    # Empty read means the peer closed on us.
                    self._terminated = True
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    )

                pending_raw_identifiers = [_.raw_id for _ in self._pending]

                for payload in payloads:
                    #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                    if self._rfc1035_prefix_mandated is True:
                        fragments = rfc1035_unpack(payload)
                    else:
                        fragments = packet_fragment(payload, *pending_raw_identifiers)

                    for fragment in fragments:
                        dns_resp = DomainNameServerReturn(fragment)

                        if any(dns_resp.id == _.id for _ in queries):
                            responses.append(dns_resp)

                            query_tbr: DomainNameServerQuery | None = None

                            for query_tbr in self._pending:
                                if query_tbr.id == dns_resp.id:
                                    break

                            if query_tbr:
                                self._pending.remove(query_tbr)
                        else:
                            # Belongs to a concurrent caller; park it for them.
                            self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                # rcode 2 (SERVFAIL) is most commonly a DNSSEC failure.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    # An advertised "h3" ALPN means the remote supports QUIC.
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        if remote_preemptive_quic_rr:
            any_specified = False

            # Mirror each TCP result as a UDP/QUIC candidate, unless the caller
            # already asked for datagram results explicitly.
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-UDP DNS resolver preconfigured for Cloudflare DNS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-UDP DNS resolver preconfigured for Google Public DNS (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("8.8.8.8", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-UDP DNS resolver preconfigured for Quad9 (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("9.9.9.9", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain-UDP DNS resolver preconfigured for AdGuard (94.140.14.140)."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream endpoint is fixed; silently discard any caller-supplied
        # "server" override while still honoring an explicit "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("94.140.14.140", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from base64 import b64encode
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from ...util import parse_url
|
||||
from .protocols import BaseResolver, ProtocolResolver
|
||||
|
||||
|
||||
class ResolverFactory(metaclass=ABCMeta):
    """Locate and instantiate a concrete :class:`BaseResolver` by reflection."""

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> BaseResolver:
        """Import ``<package>.contrib.resolver.<protocol>[._<implementation>]``
        and return an instance of the resolver class matching ``specifier``.

        Raises ``NotImplementedError`` when the module cannot be imported or
        when no class in it matches the requested specifier.
        """
        # Root package name, derived from this module's dotted path.
        package_name: str = __name__.split(".")[0]

        # Protocol values may contain dashes (e.g. "in-memory"); module names cannot.
        module_expr = f".{protocol.value.replace('-', '_')}"

        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e

        # Collect every BaseResolver subclass in the module whose class-level
        # "specifier" matches the requested one (both-None counts as a match).
        implementations: list[tuple[str, type[BaseResolver]]] = inspect.getmembers(
            resolver_module,
            lambda e: isinstance(e, type)
            and issubclass(e, BaseResolver)
            and (
                (specifier is None and e.specifier is None) or specifier == e.specifier
            ),
        )

        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        # NOTE(review): when several classes match, the last one (alphabetical
        # per inspect.getmembers) wins — presumably intentional; confirm.
        implementation_target: type[BaseResolver] = implementations.pop()[1]

        return implementation_target(**kwargs)
|
||||
|
||||
|
||||
class ResolverDescription:
    """Describe how a BaseResolver must be instantiated.

    Holds the protocol, optional specifier/implementation, target server/port,
    host patterns and free-form keyword arguments; :meth:`new` materializes the
    resolver through :class:`ResolverFactory`.
    """

    def __init__(
        self,
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        server: str | None = None,
        port: int | None = None,
        *host_patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        self.protocol = protocol
        self.specifier = specifier
        self.implementation = implementation
        self.server = server
        self.port = port
        self.host_patterns = host_patterns
        self.kwargs = kwargs

    def __setitem__(self, key: str, value: typing.Any) -> None:
        self.kwargs[key] = value

    def __contains__(self, item: str) -> bool:
        return item in self.kwargs

    def new(self) -> BaseResolver:
        """Instantiate the described resolver via :class:`ResolverFactory`."""
        kwargs = {**self.kwargs}

        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns

        return ResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> ResolverDescription:
        """Build a description from a DNS URL such as ``dou://1.1.1.1`` or
        ``doh+cloudflare://user:pass@host/path?timeout=1&hosts=a,b``.

        Query parameters become keyword arguments (lower-cased keys), with
        comma-separated and repeated values merged into lists and bool/int/float
        strings converted. ``implementation`` and ``hosts`` are extracted into
        their dedicated slots. Raises ``ValueError`` on a scheme-less URL or
        multiple ``implementation`` values.
        """
        parsed_url = parse_url(url)

        schema = parsed_url.scheme

        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")

        specifier = None
        implementation = None

        # "doh+google" -> protocol "doh", specifier "google".
        if "+" in schema:
            schema, specifier = tuple(schema.lower().split("+", 1))

        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}

        if parsed_url.path:
            kwargs["path"] = parsed_url.path

        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                # Split on the FIRST colon only: RFC 7617 allows colons in the
                # password (fixes a ValueError on "user:pa:ss").
                username, password = parsed_url.auth.split(":", 1)

                username = username.strip("'\"")
                password = password.strip("'\"")

                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"

        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)

            for parameter in parameters:
                if not parameters[parameter]:
                    continue

                parameter_insensible = parameter.lower()

                # Repeated parameter (?k=a&k=b): merge into a list.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")

                    values = []

                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue

                    kwargs[parameter_insensible] = values
                    continue

                value: str = parameters[parameter][0].lower().strip(" ")

                if parameter == "implementation":
                    implementation = value
                    continue

                # Single parameter holding a comma-separated list (?k=a,b).
                if "," in value:
                    list_of_values = value.split(",")

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            # Fold the previously-seen scalar into the list and
                            # write it back (previously the merged list was
                            # built but never stored, dropping these values).
                            list_of_values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = list_of_values
                        continue

                    kwargs[parameter_insensible] = list_of_values
                    continue

                # Best-effort scalar conversion: bool, int, then simple float.
                value_converted: bool | int | float | None = None

                if value in ["false", "true"]:
                    value_converted = value == "true"
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)

                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )

        host_patterns: list[str] = []

        # "hosts" is a dedicated positional slot, not a free-form kwarg.
        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]

        return ResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._dict import InMemoryResolver
|
||||
|
||||
__all__ = ("InMemoryResolver",)
|
||||
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ....util.url import _IPV6_ADDRZ_RE
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
from ..utils import is_ipv4, is_ipv6
|
||||
|
||||
|
||||
class InMemoryResolver(BaseResolver):
    """A purely local resolver backed by an in-process hostname→address map.

    Entries come from constructor patterns ("hostname:addr") or from
    :meth:`register`; no network traffic is ever generated.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no upstream server; drop any such kwargs.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)

        # Cap on the number of distinct hostnames kept in the map.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}

        # Patterns are (ab)used as seed records of the form "hostname:addr".
        if self._host_patterns:
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            self._host_patterns = tuple([])

        # probably about our happy eyeballs impl (sync only)
        if len(self._hosts) == 1 and len(self._hosts[list(self._hosts.keys())[0]]) == 1:
            self._unsafe_expose = True

    def recycle(self) -> BaseResolver:
        """This resolver holds no external resource; reuse the same instance."""
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Always constrained: only registered hostnames are resolvable.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """True when ``hostname`` has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Associate ``ipaddr`` with ``hostname``; duplicates are ignored."""
        with self._lock:
            if hostname not in self._hosts:
                self._hosts[hostname] = []
            else:
                for e in self._hosts[hostname]:
                    t, addr = e
                    # NOTE(review): substring membership, presumably so a stored
                    # "::1" matches a bracketed "[::1]" — confirm intent.
                    if addr in ipaddr:
                        return

            # Bracketed IPv6 ("[::1]") is stored unbracketed.
            if _IPV6_ADDRZ_RE.match(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
            elif is_ipv6(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr))
            else:
                self._hosts[hostname].append((socket.AF_INET, ipaddr))

            # Evict the oldest hostname (dict preserves insertion order) when
            # the map grows beyond the configured cap.
            if len(self._hosts) > self._maxsize:
                k = None
                for k in self._hosts.keys():
                    break
                if k:
                    self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every record registered for ``hostname``."""
        with self._lock:
            if hostname in self._hosts:
                del self._hosts[hostname]

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` from the in-memory map, mimicking ``socket.getaddrinfo``.

        Literal IP addresses pass through untouched; unknown hostnames raise
        ``socket.gaierror``.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses short-circuit the map lookup entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        with self._lock:
            if host not in self._hosts:
                raise socket.gaierror(f"no records found for hostname {host} in-memory")

            for entry in self._hosts[host]:
                addr_type, addr_target = entry

                # Honor the caller's family filter (AF_UNSPEC accepts both).
                if family != socket.AF_UNSPEC:
                    if family != addr_type:
                        continue

                results.append(
                    (
                        addr_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        (addr_target, port)
                        if addr_type == socket.AF_INET
                        else (addr_target, port, 0, 0),
                    )
                )

        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        return results
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
from ..utils import is_ipv4, is_ipv6
|
||||
|
||||
|
||||
class NullResolver(BaseResolver):
    """A resolver that purposely refuses to resolve anything.

    Literal IPv4/IPv6 addresses are passed through untouched (mirroring the
    stdlib fast-path); any actual hostname raises ``socket.gaierror``. Useful
    to deliberately disable name resolution.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # No upstream server exists for this resolver; drop such kwargs.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> BaseResolver:
        # Stateless: the same instance can be reused forever.
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Pass literal IPs through; raise for anything that needs real DNS."""
        # Input normalization, mirroring CPython's getaddrinfo defaults.
        host = "localhost" if host is None else host

        port = 0 if port is None else port
        port = int(port) if isinstance(port, str) else port
        if port < 0:
            raise socket.gaierror("Servname not supported for ai_socktype")

        host = host.decode("ascii") if isinstance(host, bytes) else host

        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET, type, 6, "", (host, port))]

        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]

        # Not a literal address: by design, this resolver never resolves names.
        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
|
||||
|
||||
|
||||
__all__ = ("NullResolver",)
|
||||
@@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from random import randint
|
||||
|
||||
from ..._constant import UDP_LINUX_GRO
|
||||
from ..._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
|
||||
from ...exceptions import LocationParseError
|
||||
from ...util.connection import _set_socket_options, allowed_gai_family
|
||||
from ...util.ssl_match_hostname import CertificateError, match_hostname
|
||||
from ...util.timeout import _DEFAULT_TIMEOUT
|
||||
from .utils import inet4_ntoa, inet6_ntoa, parse_https_rdata
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .utils import HttpsRecord
|
||||
|
||||
|
||||
class ProtocolResolver(str, Enum):
    """
    At urllib3.future we aim to propose a wide range of DNS-protocols.
    The most used techniques are available.
    """

    # Inherits from str so members compare equal to their raw scheme string
    # (e.g. the URL scheme parsed in ResolverDescription.from_url).

    #: Ask the OS native DNS layer
    SYSTEM = "system"
    #: DNS over HTTPS
    DOH = "doh"
    #: DNS over QUIC
    DOQ = "doq"
    #: DNS over TLS
    DOT = "dot"
    #: DNS over UDP (insecure)
    DOU = "dou"
    #: Manual (e.g. hosts)
    MANUAL = "in-memory"
    #: Void (e.g. purposely disable resolution)
    NULL = "null"
    #: Custom (e.g. your own implementation, use this when it does not suit any of the protocols specified)
    CUSTOM = "custom"
|
||||
|
||||
|
||||
class BaseResolver(metaclass=ABCMeta):
    """Abstract base for every DNS resolver implementation.

    A resolver optionally targets a remote ``server``/``port`` and may be
    constrained to a set of hostname ``patterns`` (glob-style, matched via
    ``match_hostname``). Subclasses must implement :meth:`close`,
    :meth:`is_available` and :meth:`getaddrinfo`.
    """

    #: Which DNS transport this class implements (see ProtocolResolver).
    protocol: typing.ClassVar[ProtocolResolver]
    #: Optional sub-variant discriminator for the protocol. None when unused.
    specifier: typing.ClassVar[str | None] = None

    #: Human-readable name of the backing implementation (e.g. "socket").
    implementation: typing.ClassVar[str]

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        self._server = server
        self._port = port
        self._host_patterns: tuple[str, ...] = patterns
        self._lock = threading.Lock()
        # kept verbatim so recycle() can rebuild an equivalent instance
        self._kwargs = kwargs

        # patterns may also arrive through kwargs (e.g. parsed from a URL query)
        if not self._host_patterns and "patterns" in kwargs:
            self._host_patterns = kwargs["patterns"]

        # allow to temporarily expose a sock that is "being" created
        # this helps with our Happy Eyeballs implementation in sync.
        self._unsafe_expose: bool = False
        self._sock_cursor: socket.socket | None = None

    def recycle(self) -> BaseResolver:
        """Build a fresh, usable clone of a closed resolver.

        Inspects the subclass ``__init__`` signature (via code varnames) to
        decide how to re-pass patterns/kwargs. Raises RuntimeError when the
        resolver is still open.
        """
        if self.is_available():
            raise RuntimeError("Attempting to recycle a Resolver that was not closed")

        # introspect the concrete constructor to know what it accepts
        args = list(self.__class__.__init__.__code__.co_varnames)
        args.remove("self")

        kwargs_cpy = deepcopy(self._kwargs)

        if self._server:
            kwargs_cpy["server"] = self._server
        if self._port:
            kwargs_cpy["port"] = self._port

        if "patterns" in args and "kwargs" in args:
            # subclass signature is (*patterns, **kwargs) style
            return self.__class__(*self._host_patterns, **kwargs_cpy)  # type: ignore[arg-type]
        elif "kwargs" in args:
            return self.__class__(**kwargs_cpy)

        return self.__class__()  # type: ignore[call-arg]

    @property
    def server(self) -> str | None:
        # target DNS server hostname/IP, None for e.g. the system resolver
        return self._server

    @property
    def port(self) -> int | None:
        # target DNS server port, None when not applicable
        return self._port

    def have_constraints(self) -> bool:
        """True when this resolver only serves a restricted set of hostnames."""
        return bool(self._host_patterns)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """
        Determine if given hostname is especially resolvable by given resolver.
        If this resolver does not have any constrained list of host, it returns None. Meaning
        it support any hostname for resolution.
        """
        if not self._host_patterns:
            return None
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        try:
            # reuse TLS hostname-matching rules to honor wildcards in patterns
            match_hostname(
                {"subjectAltName": (tuple(("DNS", e) for e in self._host_patterns))},
                hostname,
            )
        except CertificateError:
            return False
        return True

    @abstractmethod
    def close(self) -> None:
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """Determine if Resolver can receive inquiries."""
        raise NotImplementedError

    @abstractmethod
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    def create_connection(
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> socket.socket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.

        *timing_hook*, when set, receives (resolution delta, establishment
        delta, completion timestamp) for observability.
        """

        host, port = address
        # strip IPv6 literal brackets, e.g. "[::1]" -> "::1"
        if host.startswith("["):
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        # binding to a specific source address forces the matching family
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        # early validation of the hostname (IDNA also enforces label lengths)
        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        dt_pre_resolve = datetime.now(tz=timezone.utc)
        records = self.getaddrinfo(
            host,
            port,
            default_socket_family,
            socket_kind,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)
        # try each resolved endpoint in order; first successful connect wins
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address is not None:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (
                        OSError,
                        AttributeError,
                    ):  # Defensive: Windows or very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                    sock.bind(source_address)

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)

                # expose the in-flight socket so a concurrent Happy-Eyeballs
                # controller can cancel/close it mid-connect
                if self._unsafe_expose:
                    self._sock_cursor = sock

                sock.connect(sa)

                if self._unsafe_expose:
                    self._sock_cursor = None
                # Break explicitly a reference cycle
                err = None

                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )

                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )

                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. invalid port) won't improve on retry
                if isinstance(_, OverflowError):
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
|
||||
|
||||
|
||||
class ManyResolver(BaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).

    Children are partitioned into constrained resolvers (restricted to
    hostname patterns) and unconstrained ones; constrained resolvers are
    consulted first for hosts they support.
    """

    def __init__(self, *resolvers: BaseResolver) -> None:
        # no server/port of its own — it only delegates
        super().__init__(None, None)

        self._size = len(resolvers)

        self._unconstrained: list[BaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[BaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]

        # number of in-flight __resolvers() generators; used to rotate the
        # starting index so concurrent callers spread across children
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> BaseResolver:
        """Rebuild this composite by recycling every child resolver."""
        resolvers = []

        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())

        return ManyResolver(*resolvers)

    def close(self) -> None:
        # close every child, then mark the composite unusable
        for resolver in self._unconstrained + self._constrained:
            resolver.close()

        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[BaseResolver, None, None]:
        """Yield child resolvers round-robin, recycling closed ones on the fly.

        The starting offset depends on how many generators are currently
        active, which balances load between concurrent callers.
        """
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            start_idx = (self._concurrent - 1) % resolver_count

            # first pass: from the rotated start to the end of the list
            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            # second pass: wrap around to cover indices before the start
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* by delegating to children.

        Constrained resolvers that declare support for the host are tried
        first; if any was applicable but all failed, raise immediately.
        Otherwise fall through to the unconstrained pool. DNSSEC-related
        failures are never swallowed.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        tested_resolvers = []

        any_constrained_tried: bool = False

        # phase 1: constrained children that explicitly support this host
        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True

                try:
                    results = resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )

                    if results:
                        return results
                except socket.gaierror as exc:
                    # security-relevant failures must propagate, not be retried
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)

        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        # phase 2: unconstrained children, in round-robin order
        for resolver in self.__resolvers():
            try:
                results = resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )

                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
|
||||
|
||||
|
||||
class SupportedQueryType(int, Enum):
    """
    urllib3.future does not need anything else so far. let's be pragmatic.
    Each type is associated with its hex value as per the RFC.
    """

    #: IPv4 host address record
    A = 0x0001
    #: IPv6 host address record
    AAAA = 0x001C
    #: HTTPS/SVCB service binding record
    HTTPS = 0x0041
|
||||
|
||||
|
||||
class DomainNameServerQuery:
    """
    Minimalist DNS query/message to ask for A, AAAA and HTTPS records.
    Only meant for urllib3.future use. Does not cover all of possible extent of use.
    """

    def __init__(
        self, host: str, query_type: SupportedQueryType, override_id: int | None = None
    ) -> None:
        # either honor the caller-provided transaction id or draw a random one
        chosen_id = override_id if override_id is not None else randint(0x0000, 0xFFFF)
        self._id = struct.pack("!H", chosen_id)
        self._host = host
        self._query = query_type
        # 0x0100: standard query with the "recursion desired" flag set
        self._flags = struct.pack("!H", 0x0100)
        # exactly one question per message
        self._qd_count = struct.pack("!H", 1)

        # lazily-built wire representation, reused on subsequent __bytes__()
        self._cached: bytes | None = None

    @property
    def id(self) -> int:
        """Transaction identifier as an integer."""
        (value,) = struct.unpack("!H", self._id)
        return value  # type: ignore[no-any-return]

    @property
    def raw_id(self) -> bytes:
        """Transaction identifier as its 2-byte big-endian encoding."""
        return self._id

    def __repr__(self) -> str:
        return f"<Query '{self._host}' IN {self._query.name}>"

    def __bytes__(self) -> bytes:
        """Serialize the query to RFC 1035 wire format (header + question)."""
        if self._cached:
            return self._cached

        # header: ID, flags, QDCOUNT, then AN/NS/AR counts all zero
        parts = [self._id, self._flags, self._qd_count, b"\x00\x00" * 3]

        # QNAME: length-prefixed labels, terminated by the root (zero) label
        for label in self._host.split("."):
            parts.append(struct.pack("!B", len(label)) + label.encode("ascii"))
        parts.append(b"\x00")

        # QTYPE followed by QCLASS=IN (0x0001)
        parts.append(struct.pack("!HH", self._query.value, 0x0001))

        self._cached = b"".join(parts)
        return self._cached

    @staticmethod
    def bulk(host: str, *types: SupportedQueryType) -> list[DomainNameServerQuery]:
        """Build one query per requested record type for the same host."""
        return [DomainNameServerQuery(host, query_type=t) for t in types]
|
||||
|
||||
|
||||
#: Most common status code, not exhaustive at all.
#: Maps DNS RCODE integers to human-readable labels for error reporting.
COMMON_RCODE_LABEL: dict[int, str] = {
    0: "No Error",
    1: "Format Error",
    2: "Server Failure",
    3: "Non-Existent Domain",
    5: "Query Refused",
    9: "Not Authorized",
}
|
||||
|
||||
|
||||
class DomainNameServerParseException(Exception):
    """Raised when a DNS response payload cannot be parsed (truncated or malformed)."""

    ...
|
||||
|
||||
|
||||
class DomainNameServerReturn:
    """
    Minimalist DNS response parser. Allow to quickly extract key-data out of it.
    Meant for A, AAAA and HTTPS records. Basically only what we need.

    Raises:
        DomainNameServerParseException: when the payload is truncated or
            otherwise not a well-formed DNS message.
    """

    def __init__(self, payload: bytes) -> None:
        try:
            # 12-byte header: ID, FLAGS, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT
            up = struct.unpack("!HHHHHH", payload[:12])

            self._id = up[0]
            self._flags = up[1]
            self._qd_count = up[2]
            self._an_count = up[3]

            # RCODE is the low nibble of the second flags byte
            self._rcode = payload[3] & 0x0F

            self._hostname: str = ""

            idx = 12

            # walk the question name: length-prefixed labels until root label
            while True:
                c = payload[idx]

                if c == 0:
                    idx += 1
                    break

                self._hostname += payload[idx + 1 : idx + 1 + c].decode("ascii") + "."

                idx += c + 1

            self._records: list[tuple[SupportedQueryType, int, str | HttpsRecord]] = []

            if self._an_count:
                # skip the question's QTYPE + QCLASS (2 bytes each)
                idx += 4

                while idx < len(payload):
                    # answer fixed part: NAME(ptr), TYPE, CLASS, TTL
                    up = struct.unpack("!HHHI", payload[idx : idx + 10])
                    entry_size = struct.unpack("!H", payload[idx + 10 : idx + 12])[0]

                    data = payload[idx + 12 : idx + 12 + entry_size]

                    if not data:
                        # Empty RDATA: advance the cursor anyway — the previous
                        # bare `continue` here looped forever on such entries.
                        idx += 12 + entry_size
                        continue

                    if len(data) == 4:
                        decoded_data: str | HttpsRecord = inet4_ntoa(data)
                    elif len(data) == 16:
                        decoded_data = inet6_ntoa(data)
                    else:
                        decoded_data = parse_https_rdata(data)

                    try:
                        # keep (record type, TTL, decoded value); drop
                        # record types we don't model (ValueError from enum)
                        self._records.append(
                            (SupportedQueryType(up[1]), up[-1], decoded_data)
                        )
                    except ValueError:
                        pass

                    idx += 12 + entry_size
        except (struct.error, IndexError, ValueError, UnicodeDecodeError) as e:
            raise DomainNameServerParseException(
                "A protocol error occurred while parsing the DNS response payload: "
                f"{str(e)}"
            ) from e

    @property
    def id(self) -> int:
        """Transaction identifier echoed by the server."""
        return self._id  # type: ignore[no-any-return]

    @property
    def hostname(self) -> str:
        """Queried hostname, dot-terminated (e.g. ``example.com.``)."""
        return self._hostname

    @property
    def records(self) -> list[tuple[SupportedQueryType, int, str | HttpsRecord]]:
        """Parsed answers as (type, ttl, decoded value) tuples."""
        return self._records

    @property
    def is_found(self) -> bool:
        """True when at least one supported record was decoded."""
        return bool(self._records)

    @property
    def rcode(self) -> int:
        """Raw DNS response status code (0 means success)."""
        return self._rcode

    @property
    def is_ok(self) -> bool:
        """True when the server reported no error (RCODE == 0)."""
        return self._rcode == 0

    def __repr__(self) -> str:
        if self.is_ok:
            return f"<Records '{self.hostname}' {self._records}>"
        return f"<DNS Error '{self.hostname}' with Status {self.rcode} ({COMMON_RCODE_LABEL[self.rcode] if self.rcode in COMMON_RCODE_LABEL else 'Unknown'})>"
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import SystemResolver
|
||||
|
||||
__all__ = ("SystemResolver",)
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
|
||||
|
||||
class SystemResolver(BaseResolver):
    """Resolver that defers every lookup to the OS native layer via socket.getaddrinfo."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The system resolver targets no remote endpoint: discard any
        # server/port leftovers (e.g. round-tripped through recycle()).
        # pop(key, None) replaces the previous check-then-pop dance.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Like BaseResolver.support, but None and "localhost" are always accepted."""
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        if hostname == "localhost":
            return True
        return super().support(hostname)

    def recycle(self) -> BaseResolver:
        # stateless: the same instance can be reused indefinitely
        return self

    def close(self) -> None:
        pass  # no-op!

    def is_available(self) -> bool:
        # never closes, so always ready
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Delegate to socket.getaddrinfo. quic_upgrade_via_dns_rr is accepted
        for interface parity but unused — the OS layer cannot serve HTTPS RRs."""
        # the | tuple[int, bytes] is silently ignored, can't happen with our cases.
        return socket.getaddrinfo(  # type: ignore[return-value]
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import socket
|
||||
import struct
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
class HttpsRecord(typing.TypedDict):
|
||||
priority: int
|
||||
target: str
|
||||
alpn: list[str]
|
||||
ipv4hint: list[str]
|
||||
ipv6hint: list[str]
|
||||
echconfig: list[str]
|
||||
|
||||
|
||||
def inet4_ntoa(address: bytes) -> str:
    """
    Convert an IPv4 address from bytes to str.

    Renders the 4-byte packed form as dotted-quad text (e.g. "127.0.0.1").
    Raises ValueError on any other input length.
    """
    size = len(address)

    if size != 4:
        raise ValueError(
            f"IPv4 addresses are 4 bytes long, got {size} byte(s) instead"
        )

    # iterating bytes yields ints; join their decimal forms with dots
    return ".".join(str(octet) for octet in address)
|
||||
|
||||
|
||||
def inet6_ntoa(address: bytes) -> str:
    """
    Convert an IPv6 address from bytes to str.

    Produces the compressed textual form: lowercase hex groups without
    leading zeros, the earliest longest run of two-or-more zero groups
    collapsed to "::", and IPv4-compatible/mapped tails rendered as a
    dotted quad.
    """
    if len(address) != 16:
        raise ValueError(
            f"IPv6 addresses are 16 bytes long, got {len(address)} byte(s) instead"
        )

    # eight 16-bit groups, lowercase hex, leading zeros dropped ("%x" keeps "0")
    groups = [
        "%x" % int.from_bytes(address[i : i + 2], "big") for i in range(0, 16, 2)
    ]

    # find the earliest, longest run of all-zero groups (strictly longer wins)
    best_start = 0
    best_len = 0
    run_start = None
    for idx in range(9):  # idx == 8 acts as a sentinel closing a trailing run
        in_zero_group = idx < 8 and groups[idx] == "0"
        if in_zero_group:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            run_len = idx - run_start
            if run_len > best_len:
                best_start, best_len = run_start, run_len
            run_start = None

    # a lone zero group is never compressed
    if best_len > 1:
        if best_start == 0 and (
            best_len == 6 or best_len == 5 and groups[5] == "ffff"
        ):
            # We have an embedded IPv4 address ("::a.b.c.d" / "::ffff:a.b.c.d")
            prefix = "::" if best_len == 6 else "::ffff:"
            return prefix + inet4_ntoa(address[12:])
        return (
            ":".join(groups[:best_start])
            + "::"
            + ":".join(groups[best_start + best_len :])
        )

    return ":".join(groups)
|
||||
|
||||
|
||||
def packet_fragment(payload: bytes, *identifiers: bytes) -> tuple[bytes, ...]:
    """Split a buffer that may hold several DNS messages back into messages.

    Each *identifier* is the 2-byte transaction id of an expected message;
    the payload is cut at the positions where those ids are found.

    NOTE(review): this is heuristic — an id byte sequence occurring inside
    message data could cause a spurious split; presumably acceptable for
    the callers' payload sizes. TODO confirm.
    """
    results = []

    offset = 0

    start_packet_idx = []
    lead_identifier = None

    # locate the first message: its id should sit within the first 12 bytes
    # (the DNS header); a non-zero position becomes the global offset.
    for identifier in identifiers:
        idx = payload[:12].find(identifier)

        if idx == -1:
            continue

        if idx != 0:
            offset = idx

        start_packet_idx.append(idx - offset)

        lead_identifier = identifier
        break

    # locate the remaining messages anywhere in the payload
    for identifier in identifiers:
        if identifier == lead_identifier:
            continue

        if offset == 0:
            # assumes the id is preceded by a 0x02 byte here — presumably a
            # framing byte from the transport; verify against the callers
            idx = payload.find(b"\x02" + identifier)
        else:
            idx = payload.find(identifier)

        if idx == -1:
            continue

        start_packet_idx.append(idx - offset)

    if not start_packet_idx:
        raise ValueError(
            "no identifiable dns message emerged from given payload. "
            "this should not happen at all. networking issue?"
        )

    # single message found: the whole payload is that message
    if len(start_packet_idx) == 1:
        return (payload,)

    start_packet_idx = sorted(start_packet_idx)

    previous_idx = None

    # slice the payload between consecutive start positions
    for idx in start_packet_idx:
        if previous_idx is None:
            previous_idx = idx
            continue
        results.append(payload[previous_idx:idx])
        previous_idx = idx

    # last fragment runs to the end of the payload
    results.append(payload[previous_idx:])

    return tuple(results)
|
||||
|
||||
|
||||
def is_ipv4(addr: str) -> bool:
    """Return True when *addr* is accepted by the platform IPv4 parser."""
    try:
        socket.inet_aton(addr)
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def is_ipv6(addr: str) -> bool:
    """Return True when *addr* is accepted by the platform IPv6 parser."""
    try:
        socket.inet_pton(socket.AF_INET6, addr)
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def validate_length_of(hostname: str) -> None:
    """Verify the RFC 1035 size limits for a domain name.

    Raises:
        UnicodeError: when the name (surrounding dots trimmed) exceeds 253
            characters, or when any single label exceeds 63 characters.
    """
    if len(hostname.strip(".")) > 253:
        raise UnicodeError("hostname to resolve exceed 253 characters")
    # generator expression avoids materializing a throwaway list in any([...])
    if any(len(label) > 63 for label in hostname.split(".")):
        raise UnicodeError("at least one label to resolve exceed 63 characters")
|
||||
|
||||
|
||||
def rfc1035_should_read(payload: bytes) -> bool:
    """Tell whether a DNS-over-TCP buffer still awaits more bytes.

    Messages are framed by a 2-byte big-endian length prefix; walk the
    frames and report True when the trailing frame is incomplete.
    """
    if not payload:
        return False
    if len(payload) <= 2:
        # not even one full length prefix yet
        return True

    remaining = payload

    while True:
        (frame_size,) = struct.unpack("!H", remaining[:2])
        body = remaining[2:]

        if len(body) == frame_size:
            # buffer ends exactly on a frame boundary
            return False
        if len(body) < frame_size:
            # last frame is short of its announced size
            return True

        remaining = body[frame_size:]
|
||||
|
||||
|
||||
def rfc1035_unpack(payload: bytes) -> tuple[bytes, ...]:
    """Split a stream of length-prefixed DNS messages into their bodies."""
    messages: list[bytes] = []
    remaining = payload

    while remaining:
        (size,) = struct.unpack("!H", remaining[:2])
        messages.append(remaining[2 : 2 + size])
        remaining = remaining[2 + size :]

    return tuple(messages)
|
||||
|
||||
|
||||
def rfc1035_pack(message: bytes) -> bytes:
    """Frame *message* with its 2-byte big-endian length (DNS-over-TCP framing)."""
    prefix = struct.pack("!H", len(message))
    return prefix + message
|
||||
|
||||
|
||||
def read_name(data: bytes, offset: int) -> tuple[str, int]:
    """
    Read a DNS-encoded name (with compression pointers) from data[offset:].
    Returns (name, new_offset) — new_offset is just past the name's encoding
    at the original location (a pointer consumes 2 bytes there).
    """
    parts: list[str] = []
    cursor = offset

    while True:
        tag = data[cursor]

        if tag & 0xC0 == 0xC0:
            # 2-byte compression pointer: the low 14 bits address the target
            target = struct.unpack_from("!H", data, cursor)[0] & 0x3FFF
            pointed, _unused = read_name(data, target)
            parts.append(pointed)
            cursor += 2
            break

        if tag == 0:
            # root (zero-length) label terminates the name
            cursor += 1
            break

        start = cursor + 1
        parts.append(data[start : start + tag].decode())
        cursor = start + tag

    return ".".join(parts), cursor
|
||||
|
||||
|
||||
def parse_echconfigs(buf: bytes) -> list[str]:
    """
    Decode a raw ECHConfig vector into Base64 strings (one per config).

    Layout: a 2-byte total length, then repeated entries of
    2-byte config length + that many bytes of config.
    """
    if len(buf) < 2:
        return []

    (vector_len,) = struct.unpack_from("!H", buf, 0)
    limit = 2 + vector_len

    configs: list[str] = []
    cursor = 2

    # stop as soon as a full 2-byte entry header no longer fits
    while cursor + 2 <= limit:
        (cfg_len,) = struct.unpack_from("!H", buf, cursor)
        cursor += 2
        blob = buf[cursor : cursor + cfg_len]
        cursor += cfg_len
        configs.append(base64.b64encode(blob).decode())

    return configs
|
||||
|
||||
|
||||
def parse_https_rdata(rdata: bytes) -> HttpsRecord:
    """
    Parse the RDATA of an SVCB/HTTPS record.
    Returns a dict with keys: priority, target, alpn, ipv4hint, ipv6hint, echconfig.
    """
    # fixed part: 2-byte SvcPriority, then the TargetName
    off = 0
    priority = struct.unpack_from("!H", rdata, off)[0]
    off += 2

    target, off = read_name(rdata, off)

    # pull out all the key/value params (2-byte key + 2-byte length + value)
    params = {}
    while off + 4 <= len(rdata):
        key, length = struct.unpack_from("!HH", rdata, off)
        off += 4
        params[key] = rdata[off : off + length]
        off += length

    # decode ALPN (key=1), IPv4 (4), IPv6 (6), ECHConfig (5)
    def parse_alpn(buf: bytes) -> list[str]:
        # ALPN values are 1-byte-length-prefixed protocol ids
        out = []
        i: int = 0
        while i < len(buf):
            ln = buf[i]
            out.append(buf[i + 1 : i + 1 + ln].decode())
            i += 1 + ln
        return out

    alpn: list[str] = parse_alpn(params.get(1, b""))
    # ipv4hint/ipv6hint are packed address lists (4 and 16 bytes per entry)
    ipv4 = [
        inet4_ntoa(params[4][i : i + 4]) for i in range(0, len(params.get(4, b"")), 4)
    ]
    ipv6 = [
        inet6_ntoa(params[6][i : i + 16]) for i in range(0, len(params.get(6, b"")), 16)
    ]
    echconfs = parse_echconfigs(params.get(5, b""))

    return {
        "priority": priority,
        "target": target or ".",  # empty name → root
        "alpn": alpn,
        "ipv4hint": ipv4,
        "ipv6hint": ipv6,
        "echconfig": echconfs,
    }
|
||||
|
||||
|
||||
__all__ = (
|
||||
"inet4_ntoa",
|
||||
"inet6_ntoa",
|
||||
"packet_fragment",
|
||||
"is_ipv4",
|
||||
"is_ipv6",
|
||||
"validate_length_of",
|
||||
"rfc1035_pack",
|
||||
"rfc1035_unpack",
|
||||
"rfc1035_should_read",
|
||||
"parse_https_rdata",
|
||||
)
|
||||
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
This module contains provisional support for SOCKS proxies from within
|
||||
urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
|
||||
SOCKS5. To enable its functionality, either install python-socks or install this
|
||||
module with the ``socks`` extra.
|
||||
|
||||
The SOCKS implementation supports the full range of urllib3 features. It also
|
||||
supports the following SOCKS features:
|
||||
|
||||
- SOCKS4A (``proxy_url='socks4a://...``)
|
||||
- SOCKS4 (``proxy_url='socks4://...``)
|
||||
- SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
|
||||
- SOCKS5 with local DNS (``proxy_url='socks5://...``)
|
||||
- Usernames and passwords for the SOCKS proxy
|
||||
|
||||
.. note::
|
||||
It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
|
||||
your ``proxy_url`` to ensure that DNS resolution is done from the remote
|
||||
server instead of client-side when connecting to a domain name.
|
||||
|
||||
SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
|
||||
supports IPv4, IPv6, and domain names.
|
||||
|
||||
When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
|
||||
will be sent as the ``userid`` section of the SOCKS request:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
proxy_url="socks4a://<userid>@proxy-host"
|
||||
|
||||
When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
|
||||
of the ``proxy_url`` will be sent as the username/password to authenticate
|
||||
with the proxy:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
proxy_url="socks5h://<username>:<password>@proxy-host"
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
#: We purposely want to support PySocks[...] due to our shadowing of the legacy "urllib3". "Dot not disturb" policy.
|
||||
BYPASS_SOCKS_LEGACY: bool = False
|
||||
|
||||
try:
|
||||
from python_socks import (
|
||||
ProxyConnectionError,
|
||||
ProxyError,
|
||||
ProxyTimeoutError,
|
||||
ProxyType,
|
||||
)
|
||||
from python_socks.sync import Proxy
|
||||
|
||||
from ._socks_override import AsyncioProxy
|
||||
except ImportError:
|
||||
from ..exceptions import DependencyWarning
|
||||
|
||||
try:
|
||||
import socks # noqa
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
(
|
||||
"SOCKS support in urllib3.future requires the installation of an optional "
|
||||
"dependency: python-socks. For more information, see "
|
||||
"https://urllib3future.readthedocs.io/en/latest/contrib.html#socks-proxies"
|
||||
),
|
||||
DependencyWarning,
|
||||
)
|
||||
else:
|
||||
from ._socks_legacy import (
|
||||
SOCKSConnection,
|
||||
SOCKSHTTPConnectionPool,
|
||||
SOCKSHTTPSConnection,
|
||||
SOCKSHTTPSConnectionPool,
|
||||
SOCKSProxyManager,
|
||||
)
|
||||
|
||||
BYPASS_SOCKS_LEGACY = True
|
||||
|
||||
if not BYPASS_SOCKS_LEGACY:
|
||||
raise
|
||||
|
||||
if not BYPASS_SOCKS_LEGACY:
|
||||
import typing
|
||||
from socket import socket
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
# asynchronous part
|
||||
from .._async.connection import AsyncHTTPConnection, AsyncHTTPSConnection
|
||||
from .._async.connectionpool import (
|
||||
AsyncHTTPConnectionPool,
|
||||
AsyncHTTPSConnectionPool,
|
||||
)
|
||||
from .._async.poolmanager import AsyncPoolManager
|
||||
from .._typing import _TYPE_SOCKS_OPTIONS
|
||||
from ..backend import HttpVersion
|
||||
|
||||
# synchronous part
|
||||
from ..connection import HTTPConnection, HTTPSConnection
|
||||
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
|
||||
from ..contrib.ssa import AsyncSocket
|
||||
from ..exceptions import ConnectTimeoutError, NewConnectionError
|
||||
from ..poolmanager import PoolManager
|
||||
from ..util.url import parse_url
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None # type: ignore[assignment]
|
||||
|
||||
class SOCKSConnection(HTTPConnection):  # type: ignore[no-redef]
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # Keep the proxy parameters around; _new_conn() consumes them.
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    def _new_conn(self) -> socket:
        """
        Establish a new connection via the SOCKS proxy.

        Opens a raw TCP connection to the proxy itself, then asks python-socks
        to tunnel it to (self.host, self.port).

        Raises ConnectTimeoutError on a connect timeout and
        NewConnectionError on any proxy or OS-level failure.
        """
        # A SOCKS tunnel is TCP-only: keep 3-tuple options as-is, strip the
        # protocol tag from 4-tuples, and drop options flagged "udp".
        # Built unconditionally — the original code only populated
        # extra_kw["socket_options"] when self.socket_options was truthy and
        # then read it unconditionally, raising KeyError when the caller
        # passed socket_options=None or an empty sequence.
        only_tcp_options: list[typing.Any] = []

        for opt in self.socket_options or ():
            if len(opt) == 3:
                only_tcp_options.append(opt)
            elif len(opt) == 4:
                protocol: str = opt[3].lower()
                if protocol == "udp":
                    continue
                only_tcp_options.append(opt[:3])

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = Proxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            # QUIC (HTTP/3) cannot traverse a TCP SOCKS tunnel, so the
            # DNS-RR driven upgrade is explicitly disabled here.
            _socket = self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                socket_options=only_tcp_options,
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            # our dependency started to deprecate passing "_socket"
            # which is ... vital for our integration. We'll start by silencing the warning.
            # then we'll think on how to proceed.
            # A) the maintainer agrees to revert https://github.com/romis2012/python-socks/commit/173a7390469c06aa033f8dca67c827854b462bc3#diff-e4086fa970d1c98b1eb341e58cb70e9ceffe7391b2feecc4b66c7e92ea2de76fR64
            # B) the maintainer pursue the removal -> do we vendor our copy of python-socks? is there an alternative?
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", DeprecationWarning)
                return p.connect(
                    self.host,
                    self.port,
                    self.timeout,
                    _socket=_socket,
                )
        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e

        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
|
||||
|
||||
# The Verified/Unverified distinction from urllib3/connection.py does not
# need duplicating here: that module already selected the correct
# HTTPSConnection form, so this subclass automatically inherits the right
# variant.
class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):  # type: ignore[no-redef]
    pass
|
||||
|
||||
class SOCKSHTTPConnectionPool(HTTPConnectionPool):  # type: ignore[no-redef]
    # Pool that hands out SOCKS-tunneled plain-text connections.
    ConnectionCls = SOCKSConnection
|
||||
|
||||
class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):  # type: ignore[no-redef]
    # Pool that hands out SOCKS-tunneled TLS connections.
    ConnectionCls = SOCKSHTTPSConnection
|
||||
|
||||
class SOCKSProxyManager(PoolManager):  # type: ignore[no-redef]
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": SOCKSHTTPConnectionPool,
        "https": SOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # Credentials may be embedded in the URL ("user:pass@host") when not
        # supplied explicitly.
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # Map the URL scheme onto a python-socks protocol version plus the
        # remote-DNS flag (the "h"/"a" variants resolve hostnames proxy-side).
        scheme_map = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }

        if parsed.scheme not in scheme_map:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = scheme_map[parsed.scheme]

        self.proxy_url = proxy_url

        connection_pool_kw["_socks_options"] = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }

        # SOCKS is a TCP tunnel; HTTP/3 (QUIC over UDP) cannot traverse it.
        connection_pool_kw.setdefault("disabled_svn", set())
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
|
||||
|
||||
class AsyncSOCKSConnection(AsyncHTTPConnection):
    """
    A plain-text HTTP connection that connects via a SOCKS proxy.
    """

    def __init__(
        self,
        _socks_options: _TYPE_SOCKS_OPTIONS,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        # Keep the proxy parameters around; _new_conn() consumes them.
        self._socks_options = _socks_options
        super().__init__(*args, **kwargs)

    async def _new_conn(self) -> AsyncSocket:  # type: ignore[override]
        """
        Establish a new connection via the SOCKS proxy.

        Opens an async TCP connection to the proxy itself, then asks
        python-socks (our AsyncioProxy override) to tunnel it to
        (self.host, self.port).

        Raises ConnectTimeoutError on a connect timeout and
        NewConnectionError on any proxy or OS-level failure.
        """
        # A SOCKS tunnel is TCP-only: keep 3-tuple options as-is, strip the
        # protocol tag from 4-tuples, and drop options flagged "udp".
        # Built unconditionally — the original code only populated
        # extra_kw["socket_options"] when self.socket_options was truthy and
        # then read it unconditionally, raising KeyError when the caller
        # passed socket_options=None or an empty sequence.
        only_tcp_options: list[typing.Any] = []

        for opt in self.socket_options or ():
            if len(opt) == 3:
                only_tcp_options.append(opt)
            elif len(opt) == 4:
                protocol: str = opt[3].lower()
                if protocol == "udp":
                    continue
                only_tcp_options.append(opt[:3])

        try:
            assert self._socks_options["proxy_host"] is not None
            assert self._socks_options["proxy_port"] is not None

            p = AsyncioProxy(
                proxy_type=self._socks_options["socks_version"],  # type: ignore[arg-type]
                host=self._socks_options["proxy_host"],
                port=int(self._socks_options["proxy_port"]),
                username=self._socks_options["username"],
                password=self._socks_options["password"],
                rdns=self._socks_options["rdns"],
            )

            # QUIC (HTTP/3) cannot traverse a TCP SOCKS tunnel, so the
            # DNS-RR driven upgrade is explicitly disabled here.
            _socket = await self._resolver.create_connection(
                (
                    self._socks_options["proxy_host"],
                    int(self._socks_options["proxy_port"]),
                ),
                timeout=self.timeout,
                source_address=self.source_address,
                socket_options=only_tcp_options,
                quic_upgrade_via_dns_rr=False,
                timing_hook=lambda _: setattr(self, "_connect_timings", _),
            )

            return await p.connect(
                self.host,
                self.port,
                self.timeout,
                _socket,
            )
        except (SocketTimeout, ProxyTimeoutError) as e:
            raise ConnectTimeoutError(
                self,
                f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
            ) from e

        except (ProxyConnectionError, ProxyError) as e:
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e

        except OSError as e:  # Defensive: PySocks should catch all these.
            raise NewConnectionError(
                self, f"Failed to establish a new connection: {e}"
            ) from e
|
||||
|
||||
# The Verified/Unverified distinction from urllib3/connection.py does not
# need duplicating here: that module already selected the correct
# HTTPSConnection form, so this subclass automatically inherits the right
# variant.
class AsyncSOCKSHTTPSConnection(AsyncSOCKSConnection, AsyncHTTPSConnection):
    pass
|
||||
|
||||
class AsyncSOCKSHTTPConnectionPool(AsyncHTTPConnectionPool):
    # Pool that hands out SOCKS-tunneled plain-text async connections.
    ConnectionCls = AsyncSOCKSConnection
|
||||
|
||||
class AsyncSOCKSHTTPSConnectionPool(AsyncHTTPSConnectionPool):
    # Pool that hands out SOCKS-tunneled TLS async connections.
    ConnectionCls = AsyncSOCKSHTTPSConnection
|
||||
|
||||
class AsyncSOCKSProxyManager(AsyncPoolManager):
    """
    A version of the urllib3 ProxyManager that routes connections via the
    defined SOCKS proxy.
    """

    pool_classes_by_scheme = {
        "http": AsyncSOCKSHTTPConnectionPool,
        "https": AsyncSOCKSHTTPSConnectionPool,
    }

    def __init__(
        self,
        proxy_url: str,
        username: str | None = None,
        password: str | None = None,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        **connection_pool_kw: typing.Any,
    ):
        parsed = parse_url(proxy_url)

        # Credentials may be embedded in the URL ("user:pass@host") when not
        # supplied explicitly.
        if username is None and password is None and parsed.auth is not None:
            credentials = parsed.auth.split(":")
            if len(credentials) == 2:
                username, password = credentials

        # Map the URL scheme onto a python-socks protocol version plus the
        # remote-DNS flag (the "h"/"a" variants resolve hostnames proxy-side).
        scheme_map = {
            "socks5": (ProxyType.SOCKS5, False),
            "socks5h": (ProxyType.SOCKS5, True),
            "socks4": (ProxyType.SOCKS4, False),
            "socks4a": (ProxyType.SOCKS4, True),
        }

        if parsed.scheme not in scheme_map:
            raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")

        socks_version, rdns = scheme_map[parsed.scheme]

        self.proxy_url = proxy_url

        connection_pool_kw["_socks_options"] = {
            "socks_version": socks_version,
            "proxy_host": parsed.host,
            "proxy_port": parsed.port,
            "username": username,
            "password": password,
            "rdns": rdns,
        }

        # SOCKS is a TCP tunnel; HTTP/3 (QUIC over UDP) cannot traverse it.
        connection_pool_kw.setdefault("disabled_svn", set())
        connection_pool_kw["disabled_svn"].add(HttpVersion.h3)

        super().__init__(num_pools, headers, **connection_pool_kw)

        self.pool_classes_by_scheme = AsyncSOCKSProxyManager.pool_classes_by_scheme
|
||||
|
||||
|
||||
__all__ = [
    "SOCKSConnection",
    "SOCKSProxyManager",
    "SOCKSHTTPSConnection",
    "SOCKSHTTPSConnectionPool",
    "SOCKSHTTPConnectionPool",
]

# The Async* flavours only exist when the modern python-socks backend is in
# use (legacy PySocks fallback not taken); expose them conditionally.
if not BYPASS_SOCKS_LEGACY:
    __all__ += [
        "AsyncSOCKSConnection",
        "AsyncSOCKSHTTPSConnection",
        "AsyncSOCKSHTTPConnectionPool",
        "AsyncSOCKSHTTPSConnectionPool",
        "AsyncSOCKSProxyManager",
    ]
|
||||
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import platform
|
||||
import socket
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from ._timeout import timeout
|
||||
from ._gro import open_dgram_connection, DatagramReader, DatagramWriter
|
||||
|
||||
StandardTimeoutError = socket.timeout
|
||||
|
||||
try:
|
||||
from concurrent.futures import TimeoutError as FutureTimeoutError
|
||||
except ImportError:
|
||||
FutureTimeoutError = TimeoutError # type: ignore[misc]
|
||||
|
||||
try:
|
||||
AsyncioTimeoutError = asyncio.exceptions.TimeoutError
|
||||
except AttributeError:
|
||||
AsyncioTimeoutError = TimeoutError # type: ignore[misc]
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
||||
|
||||
|
||||
def _can_shutdown_and_close_selector_loop_bug() -> bool:
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows" and platform.python_version_tuple()[:2] == (
|
||||
"3",
|
||||
"7",
|
||||
):
|
||||
return int(platform.python_version_tuple()[-1]) >= 17
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Windows + asyncio bug where doing our shutdown procedure induce a crash
|
||||
# in SelectorLoop
|
||||
# File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select
|
||||
# r, w, x = select.select(r, w, w, timeout)
|
||||
# [WinError 10038] An operation was attempted on something that is not a socket
|
||||
_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = _can_shutdown_and_close_selector_loop_bug() is False
|
||||
|
||||
|
||||
class AsyncSocket:
    """
    This class is brought to add a level of abstraction to an asyncio transport (reader, or writer)
    We don't want to have two distinct code (async/sync) but rather a unified and easily verifiable
    code base.

    'ssa' stands for Simplified - Socket - Asynchronous.
    """

    def __init__(
        self,
        family: socket.AddressFamily = socket.AF_INET,
        type: socket.SocketKind = socket.SOCK_STREAM,
        proto: int = -1,
        fileno: int | None = None,
    ) -> None:
        self.family: socket.AddressFamily = family
        self.type: socket.SocketKind = type
        self.proto: int = proto
        self._fileno: int | None = fileno

        self._connect_called: bool = False
        # Set once the transport is usable; cleared during close()/TLS upgrade.
        self._established: asyncio.Event = asyncio.Event()

        # we do that everytime to forward properly options / advanced settings
        self._sock: socket.socket = socket.socket(
            family=self.family, type=self.type, proto=self.proto, fileno=fileno
        )
        # set nonblocking / or cause the loop to block with dgram socket...
        self._sock.settimeout(0)

        self._writer: asyncio.StreamWriter | DatagramWriter | None = None
        self._reader: asyncio.StreamReader | DatagramReader | None = None

        # Serialize concurrent reads/writes over the single transport.
        self._writer_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._reader_semaphore: asyncio.Semaphore = asyncio.Semaphore()

        self._addr: tuple[str, int] | tuple[str, int, int, int] | None = None

        # Timeout applied by connect()/recv(); None means "wait forever".
        self._external_timeout: float | int | None = None
        self._tls_in_tls = False

    def fileno(self) -> int:
        return self._fileno if self._fileno is not None else self._sock.fileno()

    async def wait_for_close(self) -> None:
        """Best-effort wait for the transport to finish closing."""
        # Still connected (close() not run yet): nothing to wait for.
        if self._connect_called:
            return

        if self._writer is None:
            return

        try:
            # report made in https://github.com/jawah/niquests/issues/184
            # made us believe that sometime ssl_transport is freed before
            # getting there. So we could end up there with a half broken
            # writer state. The original user was using Windows at the time.
            is_ssl = self._writer.get_extra_info("ssl_object") is not None
        except AttributeError:
            is_ssl = False

        if is_ssl:
            # Give the connection a chance to write any data in the buffer,
            # and then forcibly tear down the SSL connection.
            await asyncio.sleep(0)
            self._writer.transport.abort()

        try:
            # wait_closed can hang indefinitely!
            # on Python 3.8 and 3.9
            # there's some case where Python want an explicit EOT
            # (spoiler: it was a CPython bug) fixed in recent interpreters.
            # to circumvent this and still have a proper close
            # we enforce a maximum delay (1000ms).
            async with timeout(1):
                await self._writer.wait_closed()
        except TimeoutError:
            pass

    def close(self) -> None:
        """Tear down the transport and release the underlying fd.

        The body is defensive on purpose: uvloop and Windows'
        SelectorEventLoop each break parts of the usual shutdown/close
        sequence, so every step tolerates failure.
        """
        if self._writer is not None:
            self._writer.close()

        edge_case_close_bug_exist = _CPYTHON_SELECTOR_CLOSE_BUG_EXIST

        # Windows + asyncio + asyncio.SelectorEventLoop limits us on how far
        # we can safely shutdown the socket.
        if not edge_case_close_bug_exist and platform.system() == "Windows":
            if hasattr(asyncio, "SelectorEventLoop") and isinstance(
                asyncio.get_running_loop(), asyncio.SelectorEventLoop
            ):
                edge_case_close_bug_exist = True

        try:
            # see https://github.com/MagicStack/uvloop/issues/241
            # and https://github.com/jawah/niquests/issues/166
            # probably not just uvloop.
            uvloop_edge_case_bug = False

            # keep track of our clean exit procedure
            shutdown_called = False
            close_called = False

            if hasattr(self._sock, "shutdown"):
                try:
                    self._sock.shutdown(socket.SHUT_RD)
                    shutdown_called = True
                except TypeError:
                    uvloop_edge_case_bug = True
                    # uvloop don't support shutdown! and sometime does not support close()...
                    # see https://github.com/jawah/niquests/issues/166 for ctx.
                    try:
                        self._sock.close()
                        close_called = True
                    except TypeError:
                        # last chance of releasing properly the underlying fd!
                        try:
                            direct_sock = socket.socket(fileno=self._sock.fileno())
                        except (OSError, ValueError):
                            pass
                        else:
                            try:
                                direct_sock.shutdown(socket.SHUT_RD)
                                shutdown_called = True
                            except OSError:
                                warnings.warn(
                                    (
                                        "urllib3-future is unable to properly close your async socket. "
                                        "This mean that you are probably using an asyncio implementation like uvloop "
                                        "that does not support shutdown() or/and close() on the socket transport. "
                                        "This will lead to unclosed socket (fd)."
                                    ),
                                    ResourceWarning,
                                )
                            finally:
                                direct_sock.detach()

            # we have to force call close() on our sock object (even after shutdown).
            # or we'll get a resource warning for sure!
            if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"):
                if not uvloop_edge_case_bug and not edge_case_close_bug_exist:
                    try:
                        self._sock.close()
                        close_called = True
                    except (OSError, TypeError):
                        pass

            if not close_called or not shutdown_called:
                # this branch detect whether we have an asyncio.TransportSocket instead of socket.socket.
                if hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                    try:
                        self._sock._sock.close()
                    except (AttributeError, OSError, TypeError):
                        pass

        except (
            OSError
        ):  # branch where we failed to connect and still try to release resource
            if isinstance(self._sock, socket.socket):
                try:
                    self._sock.close()  # don't call close on asyncio.TransportSocket
                except (OSError, TypeError, AttributeError):
                    pass
            elif hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                try:
                    self._sock._sock.detach()
                except (AttributeError, OSError, TypeError):
                    pass

        self._connect_called = False
        self._established.clear()

    async def wait_for_readiness(self) -> None:
        await self._established.wait()

    def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None:
        self._sock.setsockopt(__level, __optname, __value)

    @typing.overload
    def getsockopt(self, __level: int, __optname: int) -> int: ...

    @typing.overload
    def getsockopt(self, __level: int, __optname: int, buflen: int) -> bytes: ...

    def getsockopt(
        self, __level: int, __optname: int, buflen: int | None = None
    ) -> int | bytes:
        if buflen is None:
            return self._sock.getsockopt(__level, __optname)
        return self._sock.getsockopt(__level, __optname, buflen)

    def should_connect(self) -> bool:
        return self._connect_called is False

    async def connect(self, addr: tuple[str, int] | tuple[str, int, int, int]) -> None:
        """Connect the underlying socket and set up the asyncio transport."""
        if self._connect_called:
            raise OSError(
                "attempted to connect twice on a already established connection"
            )

        self._connect_called = True

        # there's a particularity on Windows
        # we must not forward non-IP in addr due to
        # a limitation in the network bridge used in asyncio
        if platform.system() == "Windows":
            from ..resolver.utils import is_ipv4, is_ipv6

            host, port = addr[:2]

            if not is_ipv4(host) and not is_ipv6(host):
                res = await asyncio.get_running_loop().getaddrinfo(
                    host,
                    port,
                    family=self.family,
                    type=self.type,
                )

                if not res:
                    raise socket.gaierror(f"unable to resolve hostname {host}")

                addr = res[0][-1]

        if self._external_timeout is not None:
            try:
                async with timeout(self._external_timeout):
                    await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                # Allow the caller to retry connect() after a timeout.
                self._connect_called = False
                raise StandardTimeoutError from e
            except RuntimeError:
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )
        else:
            try:
                await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except RuntimeError:  # Defensive: CPython might raise RuntimeError if there is a FD allocation error.
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )

        if self.type == socket.SOCK_STREAM or self.type == -1:  # type: ignore[comparison-overlap]
            self._reader, self._writer = await asyncio.open_connection(sock=self._sock)
        elif self.type == socket.SOCK_DGRAM:
            self._reader, self._writer = await open_dgram_connection(sock=self._sock)

        # can become an asyncio.TransportSocket
        assert self._writer is not None
        self._sock = self._writer.get_extra_info("socket", self._sock)

        self._addr = addr
        self._established.set()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Upgrade the established stream to TLS in place and morph into
        SSLAsyncSocket."""
        await self._established.wait()
        self._established.clear()

        # only if Python <= 3.10
        try:
            setattr(
                asyncio.sslproto._SSLProtocolTransport,  # type: ignore[attr-defined]
                "_start_tls_compatible",
                True,
            )
        except AttributeError:
            pass

        if self.type == socket.SOCK_STREAM:
            assert self._writer is not None
            assert isinstance(self._writer, asyncio.StreamWriter)

            # bellow is hard to maintain. Starting with 3.11+, it is useless.
            protocol = self._writer._protocol  # type: ignore[attr-defined]
            await self._writer.drain()

            new_transport = await self._writer._loop.start_tls(  # type: ignore[attr-defined]
                self._writer._transport,  # type: ignore[attr-defined]
                protocol,
                ctx,
                server_side=False,
                server_hostname=server_hostname,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )

            self._writer._transport = new_transport  # type: ignore[attr-defined]

            transport = self._writer.transport
            protocol._stream_writer = self._writer
            protocol._transport = transport
            protocol._over_ssl = transport.get_extra_info("sslcontext") is not None

            self._tls_ctx = ctx
        else:
            raise RuntimeError("Unsupported socket type")

        self._established.set()
        self.__class__ = SSLAsyncSocket

        return self  # type: ignore[return-value]

    async def recv(self, size: int = -1) -> bytes | list[bytes]:
        """Receive data from the socket.

        Returns ``bytes`` for a single datagram (or stream chunk), or
        ``list[bytes]`` when GRO / batch-receive delivered multiple
        coalesced datagrams in one syscall. The caller can then feed
        all segments to the QUIC state-machine in a tight loop before
        probing, avoiding per-datagram overhead."""
        if size == -1:
            size = 65536
        assert self._reader is not None
        await self._established.wait()
        await self._reader_semaphore.acquire()

        try:
            if self._external_timeout is not None:
                try:
                    async with timeout(self._external_timeout):
                        return await self._reader.read(n=size)
                except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                    # bugfix: the timeout path used to release the semaphore
                    # here AND in the finally clause; the double release let
                    # the counter grow past one and allowed concurrent reads
                    # afterwards. The finally clause alone is sufficient.
                    raise StandardTimeoutError from e
                except OSError as e:  # Defensive: treat any OSError as ConnReset!
                    raise ConnectionResetError() from e
            return await self._reader.read(n=size)
        finally:
            self._reader_semaphore.release()

    async def read_exact(self, size: int = -1) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv(size=size)

    async def read(self) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv()

    async def sendall(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Write *data* to the transport and wait for the drain."""
        assert self._writer is not None
        await self._established.wait()
        await self._writer_semaphore.acquire()
        try:
            self._writer.write(data)  # type: ignore[arg-type]
            await self._writer.drain()
        finally:
            self._writer_semaphore.release()

    async def write_all(
        self, data: bytes | bytearray | memoryview | list[bytes]
    ) -> None:
        """Just an alias for sendall(), it is needed due to our custom AsyncSocks override."""
        await self.sendall(data)

    async def send(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        await self.sendall(data)

    def settimeout(self, __value: float | None = None) -> None:
        self._external_timeout = __value

    def gettimeout(self) -> float | None:
        return self._external_timeout

    def getpeername(self) -> tuple[str, int]:
        return self._sock.getpeername()  # type: ignore[no-any-return]

    def bind(self, addr: tuple[str, int]) -> None:
        self._sock.bind(addr)
|
||||
|
||||
|
||||
class SSLAsyncSocket(AsyncSocket):
    """AsyncSocket variant whose transport carries an established TLS session."""

    _tls_ctx: ssl.SSLContext
    _tls_in_tls: bool

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        return self.sslobj.getpeercert(binary_form=binary_form)  # type: ignore[return-value]

    def selected_alpn_protocol(self) -> str | None:
        return self.sslobj.selected_alpn_protocol()

    @property
    def sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        # The SSL object lives on the transport; surface it or fail loudly.
        if self._writer is not None:
            candidate: ssl.SSLSocket | ssl.SSLObject | None = (
                self._writer.get_extra_info("ssl_object")
            )
            if candidate is not None:
                return candidate

        raise RuntimeError(
            '"ssl_object" could not be extracted from this SslAsyncSock instance'
        )

    def version(self) -> str | None:
        return self.sslobj.version()

    @property
    def context(self) -> ssl.SSLContext:
        return self.sslobj.context

    @property
    def _sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        # Compatibility alias for code poking at the private CPython attribute.
        return self.sslobj

    def cipher(self) -> tuple[str, str, int] | None:
        return self.sslobj.cipher()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        # Wrapping an already-TLS socket means TLS-in-TLS (proxy tunneling).
        self._tls_in_tls = True

        return await super().wrap_socket(
            ctx,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
|
||||
|
||||
|
||||
def _has_complete_support_dgram() -> bool:
|
||||
"""A bug exist in PyPy asyncio implementation that prevent us to use a DGRAM socket.
|
||||
This piece of code inform us, potentially, if PyPy has fixed the winapi implementation.
|
||||
See https://github.com/pypy/pypy/issues/4008 and https://github.com/jawah/niquests/pull/87
|
||||
|
||||
The stacktrace look as follows:
|
||||
File "C:\\hostedtoolcache\\windows\\PyPy\3.10.13\x86\\Lib\asyncio\\windows_events.py", line 594, in connect
|
||||
_overlapped.WSAConnect(conn.fileno(), address)
|
||||
AttributeError: module '_overlapped' has no attribute 'WSAConnect'
|
||||
"""
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows" and platform.python_implementation() == "PyPy":
|
||||
try:
|
||||
import _overlapped # type: ignore[import-not-found]
|
||||
except ImportError: # Defensive:
|
||||
return False
|
||||
|
||||
if hasattr(_overlapped, "WSAConnect"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Public surface of the ssa (Simplified Socket Asynchronous) module.
__all__ = (
    "AsyncSocket",
    "SSLAsyncSocket",
    "_has_complete_support_dgram",
)
|
||||
@@ -0,0 +1,640 @@
|
||||
"""
|
||||
High-performance asyncio DatagramTransport with Linux-specific UDP
|
||||
receive/send coalescing:
|
||||
|
||||
- GRO (receive): ``setsockopt(SOL_UDP, UDP_GRO)`` + ``recvmsg`` cmsg
|
||||
- GSO (send): ``sendmsg`` with ``UDP_SEGMENT`` cmsg
|
||||
|
||||
All other platforms fall back to the standard asyncio DatagramTransport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import socket
|
||||
import struct
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from ..._constant import UDP_LINUX_GRO, UDP_LINUX_SEGMENT
|
||||
|
||||
_UINT16 = struct.Struct("=H")
|
||||
|
||||
_DEFAULT_GRO_BUF = 65535
|
||||
|
||||
# Flow control watermarks for the custom write queue
|
||||
_HIGH_WATERMARK = 64 * 1024
|
||||
_LOW_WATERMARK = 16 * 1024
|
||||
|
||||
# GSO kernel limit: max segments per sendmsg call
|
||||
_GSO_MAX_SEGMENTS = 64
|
||||
|
||||
|
||||
def _sock_has_gro(sock: socket.socket) -> bool:
    """Check if GRO is enabled on *sock* (caller must have set it)."""
    try:
        enabled = sock.getsockopt(socket.SOL_UDP, UDP_LINUX_GRO)
    except OSError:
        # Kernel too old / not Linux: no GRO.
        return False
    return enabled == 1
|
||||
|
||||
|
||||
def _sock_has_gso(sock: socket.socket) -> bool:
    """Check if the kernel supports GSO on *sock*."""
    try:
        # A successful read of the option is proof enough of support.
        sock.getsockopt(socket.SOL_UDP, UDP_LINUX_SEGMENT)
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def _split_gro_buffer(buf: bytes, segment_size: int) -> list[bytes]:
|
||||
if segment_size <= 0 or len(buf) <= segment_size:
|
||||
return [buf]
|
||||
segments = []
|
||||
mv = memoryview(buf)
|
||||
for offset in range(0, len(buf), segment_size):
|
||||
segments.append(bytes(mv[offset : offset + segment_size]))
|
||||
return segments
|
||||
|
||||
|
||||
def _group_by_segment_size(datagrams: list[bytes]) -> list[tuple[int, list[bytes]]]:
|
||||
"""Group consecutive same-size datagrams for Linux UDP GSO.
|
||||
|
||||
GSO requires all segments to be the same size (except the last,
|
||||
which may be shorter). Max 64 segments per ``sendmsg`` call."""
|
||||
if not datagrams:
|
||||
return []
|
||||
groups: list[tuple[int, list[bytes]]] = []
|
||||
current_size = len(datagrams[0])
|
||||
current_group: list[bytes] = [datagrams[0]]
|
||||
for dgram in datagrams[1:]:
|
||||
if len(dgram) == current_size and len(current_group) < _GSO_MAX_SEGMENTS:
|
||||
current_group.append(dgram)
|
||||
else:
|
||||
groups.append((current_size, current_group))
|
||||
current_size = len(dgram)
|
||||
current_group = [dgram]
|
||||
groups.append((current_size, current_group))
|
||||
return groups
|
||||
|
||||
|
||||
def sync_recv_gro(
    sock: socket.socket, bufsize: int, gro_segment_size: int = 1280
) -> bytes | list[bytes]:
    """Blocking recvmsg with GRO cmsg parsing. Returns bytes or list[bytes].

    A single ``recvmsg`` may deliver several kernel-coalesced UDP
    segments; the ``SOL_UDP``/``UDP_GRO`` control message carries the
    per-segment size. When no such cmsg is present, *gro_segment_size*
    is assumed.

    Returns ``b""`` on an empty read, a single ``bytes`` datagram when
    no coalescing occurred, or a ``list[bytes]`` of split segments.
    """
    # Ancillary buffer sized for exactly one uint16 cmsg (the GRO segment size).
    ancbufsize = socket.CMSG_SPACE(_UINT16.size)

    # NOTE(review): the peer address returned by recvmsg is discarded —
    # presumably the socket is connected; confirm against callers.
    data, ancdata, _flags, addr = sock.recvmsg(bufsize, ancbufsize)

    if not data:
        return b""

    segment_size = gro_segment_size

    # Prefer the kernel-reported segment size over the caller's default.
    for cmsg_level, cmsg_type, cmsg_data in ancdata:
        if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
            (segment_size,) = _UINT16.unpack(cmsg_data[:2])
            break

    # Fits within one segment: nothing was coalesced.
    if len(data) <= segment_size:
        return data

    return _split_gro_buffer(data, segment_size)
|
||||
|
||||
|
||||
def sync_sendmsg_gso(sock: socket.socket, datagrams: list[bytes]) -> None:
    """Batch-send datagrams using GSO. Falls back to individual sends.

    Consecutive equal-size datagrams are concatenated and dispatched in
    a single ``sendmsg`` carrying a ``UDP_SEGMENT`` cmsg so the kernel
    re-splits them on the wire. A group of one is sent directly, since
    GSO only pays off with multiple segments.

    Assumes *sock* is connected (no destination address is passed);
    OS errors propagate to the caller.
    """
    for segment_size, group in _group_by_segment_size(datagrams):
        if len(group) == 1:
            sock.sendall(group[0])
            continue
        buf = b"".join(group)
        sock.sendmsg(
            [buf],
            [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
        )
|
||||
|
||||
|
||||
class _OptimizedDatagramTransport(asyncio.DatagramTransport):
    """Datagram transport with Linux UDP GRO (receive) / GSO (send) support.

    Reads with ``recvmsg`` and parses the ``UDP_GRO`` cmsg to split
    kernel-coalesced buffers into individual datagrams; batches writes
    via ``sendmsg`` + ``UDP_SEGMENT`` when GSO is enabled. Maintains its
    own write queue with pause/resume flow control driven by the module
    watermarks (_HIGH_WATERMARK / _LOW_WATERMARK).
    """

    # __slots__ keeps per-instance memory down; one transport per connection.
    __slots__ = (
        "_loop",
        "_sock",
        "_protocol",
        "_address",
        "_gro_enabled",
        "_gso_enabled",
        "_gro_segment_size",
        "_recv_buf_size",
        "_closing",
        "_closed_fut",
        "_extra",
        "_paused",
        "_write_ready",
        "_send_queue",
        "_buffer_size",
        "_protocol_paused",
    )

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        sock: socket.socket,
        protocol: asyncio.DatagramProtocol,
        address: tuple[str, int] | None,
        gro_enabled: bool,
        gso_enabled: bool,
        gro_segment_size: int,
    ) -> None:
        super().__init__()
        self._loop = loop
        self._sock = sock
        self._protocol = protocol
        # Peer address; None means the socket is used unconnected and
        # sendto() must be given an explicit address by the caller.
        self._address = address
        self._gro_enabled = gro_enabled
        self._gso_enabled = gso_enabled
        self._gro_segment_size = gro_segment_size
        self._closing = False
        self._closed_fut: asyncio.Future[None] = loop.create_future()
        self._paused = False
        # True while the socket accepts writes without blocking; flipped
        # to False on BlockingIOError until _on_write_ready drains the queue.
        self._write_ready = True

        # Write buffer state
        self._send_queue: deque[tuple[bytes, tuple[str, int] | None]] = (
            collections.deque()
        )
        self._buffer_size = 0
        self._protocol_paused = False

        # With GRO one recvmsg may return up to 64 KiB of coalesced
        # segments; without it a single datagram of gro_segment_size
        # is the most we expect per read.
        self._recv_buf_size = _DEFAULT_GRO_BUF if gro_enabled else gro_segment_size

        self._extra = {
            "peername": address,
            "socket": sock,
            "sockname": sock.getsockname(),
        }

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Mirror asyncio's transport extra-info lookup (peername/socket/sockname)."""
        return self._extra.get(name, default)

    def is_closing(self) -> bool:
        return self._closing

    def close(self) -> None:
        """Stop reading and close once the write queue has drained."""
        if self._closing:
            return
        self._closing = True
        self._loop.remove_reader(self._sock.fileno())
        # Drain the write queue gracefully in the background
        # (queued data implies a writer callback is registered, which
        # calls _call_connection_lost when the queue empties).
        if not self._send_queue:
            self._loop.call_soon(self._call_connection_lost, None)

    def abort(self) -> None:
        """Close immediately, discarding any queued outgoing data."""
        self._closing = True
        self._call_connection_lost(None)

    def _call_connection_lost(self, exc: Exception | None) -> None:
        # Best-effort deregistration: either callback may not be registered.
        try:
            self._loop.remove_reader(self._sock.fileno())
            self._loop.remove_writer(self._sock.fileno())
        except Exception:
            pass
        # Always close the socket, even if connection_lost raises.
        try:
            self._protocol.connection_lost(exc)
        finally:
            self._sock.close()
            if not self._closed_fut.done():
                self._closed_fut.set_result(None)

    def sendto(self, data: bytes, addr: tuple[str, int] | None = None) -> None:  # type: ignore[override]
        """Send one datagram, queueing it when the socket would block."""
        if self._closing:
            raise OSError("Transport is closing")

        target = addr or self._address
        # While a writer callback is pending, keep ordering by queueing.
        if not self._write_ready:
            self._queue_write(data, target)
            return

        try:
            if target is not None:
                self._sock.sendto(data, target)
            else:
                # Connected socket: no explicit destination needed.
                self._sock.send(data)
        except BlockingIOError:
            self._write_ready = False
            self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
            self._queue_write(data, target)
        except OSError as exc:
            # Match asyncio semantics: report, don't raise, on send errors.
            self._protocol.error_received(exc)

    def sendto_many(self, datagrams: list[bytes]) -> None:
        """Send multiple datagrams, using GSO when available.

        Falls back to individual ``sendto`` calls when GSO is not
        supported or the socket write buffer is full."""
        if self._closing:
            raise OSError("Transport is closing")

        if not self._write_ready:
            target = self._address
            for dgram in datagrams:
                self._queue_write(dgram, target)
            return

        if self._gso_enabled:
            self._send_linux_gso(datagrams)
        else:
            for dgram in datagrams:
                self.sendto(dgram)

    def _send_linux_gso(self, datagrams: list[bytes]) -> None:
        # Each group holds consecutive equal-size datagrams (<= 64 per call).
        for segment_size, group in _group_by_segment_size(datagrams):
            if len(group) == 1:
                # Single datagram — plain send (GSO needs >1 segment)
                try:
                    self._sock.send(group[0])
                except BlockingIOError:
                    self._write_ready = False
                    self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                    self._queue_write(group[0], self._address)
                    # NOTE(review): any remaining groups are dropped here
                    # rather than queued — confirm callers retry; see the
                    # multi-segment branch below which queues its own group.
                    return
                except OSError as exc:
                    self._protocol.error_received(exc)
                    continue

            buf = b"".join(group)
            try:
                self._sock.sendmsg(
                    [buf],
                    [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
                )
            except BlockingIOError:
                self._write_ready = False
                self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                # Queue individual datagrams as fallback
                for dgram in group:
                    self._queue_write(dgram, self._address)
                return
            except OSError as exc:
                self._protocol.error_received(exc)

    def _queue_write(self, data: bytes, addr: tuple[str, int] | None) -> None:
        # Buffer an outgoing datagram and apply backpressure if needed.
        self._send_queue.append((data, addr))
        self._buffer_size += len(data)
        self._maybe_pause_protocol()

    def _maybe_pause_protocol(self) -> None:
        # Ask the protocol to stop writing once the queue passes the
        # high watermark; AttributeError tolerated for minimal protocols.
        if self._buffer_size >= _HIGH_WATERMARK and not self._protocol_paused:
            self._protocol_paused = True
            try:
                self._protocol.pause_writing()
            except AttributeError:
                pass

    def _maybe_resume_protocol(self) -> None:
        # Counterpart to _maybe_pause_protocol, at the low watermark.
        if self._protocol_paused and self._buffer_size <= _LOW_WATERMARK:
            self._protocol_paused = False
            try:
                self._protocol.resume_writing()
            except AttributeError:
                pass

    def _on_write_ready(self) -> None:
        # Writer callback: flush queued datagrams until the socket
        # blocks again or the queue is empty.
        while self._send_queue:
            data, addr = self._send_queue[0]
            try:
                if addr is not None:
                    self._sock.sendto(data, addr)
                else:
                    self._sock.send(data)
            except BlockingIOError:
                # Still not writable; stay registered and retry later.
                return
            except OSError as exc:
                # Report the error but drop the datagram below so the
                # queue cannot wedge on a permanently failing entry.
                self._protocol.error_received(exc)

            self._send_queue.popleft()
            self._buffer_size -= len(data)

        self._maybe_resume_protocol()
        self._write_ready = True
        self._loop.remove_writer(self._sock.fileno())

        # close() deferred teardown until the queue drained.
        if self._closing:
            self._call_connection_lost(None)

    def pause_reading(self) -> None:
        if not self._paused:
            self._paused = True
            self._loop.remove_reader(self._sock.fileno())

    def resume_reading(self) -> None:
        if self._paused:
            self._paused = False
            self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _start(self) -> None:
        # Defer connection_made to the next loop iteration (asyncio style)
        # and begin polling for readability.
        self._loop.call_soon(self._protocol.connection_made, self)
        self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _on_readable(self) -> None:
        if self._closing:
            return
        self._recv_linux_gro()

    def _recv_linux_gro(self) -> None:
        # Reader callback: drain the socket until it would block.
        ancbufsize = socket.CMSG_SPACE(_UINT16.size)
        while True:
            try:
                data, ancdata, _flags, addr = self._sock.recvmsg(
                    self._recv_buf_size, ancbufsize
                )
            except BlockingIOError:
                return
            except OSError as exc:
                self._protocol.error_received(exc)
                return

            if not data:
                return

            # Kernel-reported GRO segment size wins over the configured default.
            segment_size = self._gro_segment_size
            for cmsg_level, cmsg_type, cmsg_data in ancdata:
                if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
                    (segment_size,) = _UINT16.unpack(cmsg_data[:2])
                    break

            if len(data) <= segment_size:
                self._protocol.datagram_received(data, addr)
            else:
                # Coalesced read: hand the whole batch to the protocol in
                # one call via the non-standard datagrams_received hook.
                segments = _split_gro_buffer(data, segment_size)
                self._protocol.datagrams_received(segments, addr)  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def create_udp_endpoint(
    loop: asyncio.AbstractEventLoop,
    protocol_factory: Callable[[], asyncio.DatagramProtocol],
    *,
    local_addr: tuple[str, int] | None = None,
    remote_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    reuse_port: bool = False,
    gro_segment_size: int = 1280,
    sock: socket.socket | None = None,
) -> tuple[asyncio.DatagramTransport, asyncio.DatagramProtocol]:
    """Create a datagram endpoint, preferring the GRO/GSO transport.

    When *sock* is given, it is used as-is (bind/connect are skipped;
    ``local_addr``/``remote_addr``/``reuse_port`` are ignored — the
    caller is expected to have configured the socket, including any
    GRO setsockopt). Otherwise a non-blocking UDP socket is created,
    optionally bound and connected.

    If the socket supports neither GRO nor GSO, this falls back to the
    standard ``loop.create_datagram_endpoint``.
    """
    if sock is not None:
        # Caller provided a pre-connected socket — skip creation/bind/connect.
        try:
            connected_addr = sock.getpeername()
        except OSError:
            # Unconnected socket: sends will need explicit addresses.
            connected_addr = None
    else:
        # 1. Resolve Addresses
        if family == socket.AF_UNSPEC:
            target_addr = local_addr or remote_addr
            if target_addr:
                # Let the resolver pick the address family for us.
                infos = await loop.getaddrinfo(
                    target_addr[0], target_addr[1], type=socket.SOCK_DGRAM
                )
                family = infos[0][0]
            else:
                family = socket.AF_INET

        # 2. Create Socket
        sock = socket.socket(family, socket.SOCK_DGRAM)
        sock.setblocking(False)

        if reuse_port:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if local_addr:
            sock.bind(local_addr)

        connected_addr = None

        if remote_addr:
            await loop.sock_connect(sock, remote_addr)
            connected_addr = remote_addr

    # 3. Determine capabilities — the caller is responsible for
    # enabling GRO via setsockopt before handing us the socket.
    gro_enabled = _sock_has_gro(sock)
    gso_enabled = _sock_has_gso(sock)

    if not gro_enabled and not gso_enabled:
        # No kernel offload available — plain asyncio endpoint.
        return await loop.create_datagram_endpoint(
            lambda: protocol_factory(), sock=sock
        )

    # 4. Wire up optimized transport
    protocol = protocol_factory()

    transport = _OptimizedDatagramTransport(
        loop=loop,
        sock=sock,
        protocol=protocol,
        address=connected_addr,
        gro_enabled=gro_enabled,
        gso_enabled=gso_enabled,
        gro_segment_size=gro_segment_size,
    )

    transport._start()

    return transport, protocol
|
||||
|
||||
|
||||
class DatagramReader:
    """API-compatible with ``asyncio.StreamReader`` (duck-typed) so that
    ``AsyncSocket`` can assign an instance to ``self._reader`` and the
    existing ``recv()`` code works unchanged.

    When GRO delivers multiple coalesced segments in a single syscall,
    ``feed_datagrams()`` stores them as a single ``list[bytes]`` entry.
    ``read()`` then returns that list directly so the caller can feed
    all segments to the QUIC state-machine in one pass before probing —
    avoiding the per-datagram recv→feed→probe round-trip overhead."""

    def __init__(self) -> None:
        # Each entry is either one datagram (bytes) or one GRO batch (list).
        self._buffer: deque[bytes | list[bytes]] = collections.deque()
        # Pending read() future; at most one reader waits at a time.
        self._waiter: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._eof = False

    def feed_datagram(self, data: bytes, addr: Any) -> None:
        """Feed a single (non-coalesced) datagram."""
        self._buffer.append(data)
        self._wake_waiter()

    def feed_datagrams(self, data: list[bytes], addr: Any) -> None:
        """Feed a batch of coalesced datagrams as a single entry."""
        self._buffer.append(data)
        self._wake_waiter()

    def set_exception(self, exc: BaseException) -> None:
        # Delivered to the reader on the next read() (buffered data first).
        self._exception = exc
        self._wake_waiter()

    def connection_lost(self, exc: BaseException | None) -> None:
        self._eof = True
        if exc is not None:
            self._exception = exc
        self._wake_waiter()

    def _wake_waiter(self) -> None:
        waiter = self._waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)

    async def read(self, n: int = -1) -> bytes | list[bytes]:
        """Return the next entry from the buffer.

        * ``bytes`` — a single datagram (non-coalesced).
        * ``list[bytes]`` — a batch of coalesced datagrams from one
          GRO syscall.
        * ``b""`` — EOF.
        """
        # Buffered data takes precedence over errors/EOF, mirroring
        # StreamReader semantics. ``n`` is accepted for API parity only.
        if self._buffer:
            return self._buffer.popleft()

        if self._exception is not None:
            raise self._exception

        if self._eof:
            return b""

        # Nothing available: park until a feed_*/EOF/exception wakes us.
        self._waiter = asyncio.get_running_loop().create_future()
        try:
            await self._waiter
        finally:
            self._waiter = None

        if self._buffer:
            return self._buffer.popleft()

        if self._exception is not None:
            raise self._exception

        # Woken with no data and no error: treat as EOF.
        return b""
|
||||
|
||||
|
||||
class DatagramWriter:
    """API-compatible with ``asyncio.StreamWriter`` (duck-typed) so that
    ``AsyncSocket`` can assign an instance to ``self._writer`` and the
    existing ``sendall()``, ``close()``, ``wait_for_close()`` code works
    unchanged."""

    def __init__(
        self,
        transport: asyncio.DatagramTransport,
    ) -> None:
        self._transport = transport
        # Destination for unconnected-style sendto; None for connected sockets.
        self._address: tuple[str, int] | None = transport.get_extra_info("peername")
        # Set by the bridge protocol's connection_lost; awaited in wait_closed().
        self._closed_event = asyncio.Event()
        # Flow-control state driven by _pause_writing/_resume_writing.
        self._paused = False
        self._drain_waiter: asyncio.Future[None] | None = None

    @property
    def transport(self) -> asyncio.DatagramTransport:
        return self._transport

    def write(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Send one datagram, or a batch when given a list.

        Writes on a closing transport are silently dropped, matching
        StreamWriter's tolerance of late writes during teardown.
        """
        if self._transport.is_closing():
            return
        if isinstance(data, list):
            # Batch path: use GSO-capable sendto_many when the transport has it.
            if hasattr(self._transport, "sendto_many"):
                self._transport.sendto_many(data)
            else:
                # Plain asyncio transport — send individually
                for dgram in data:
                    self._transport.sendto(dgram, self._address)
        else:
            self._transport.sendto(bytes(data), self._address)

    async def drain(self) -> None:
        """Block until the transport resumes writing (no-op when not paused)."""
        if not self._paused:
            return
        self._drain_waiter = asyncio.get_running_loop().create_future()
        try:
            await self._drain_waiter
        finally:
            self._drain_waiter = None

    def close(self) -> None:
        self._transport.close()

    async def wait_closed(self) -> None:
        # Resolved when the protocol's connection_lost fires.
        await self._closed_event.wait()

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        return self._transport.get_extra_info(name, default)

    def _pause_writing(self) -> None:
        # Called by the bridge protocol on transport backpressure.
        self._paused = True

    def _resume_writing(self) -> None:
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None and not waiter.done():
            waiter.set_result(None)
|
||||
|
||||
|
||||
class _DatagramBridgeProtocol(asyncio.DatagramProtocol):
|
||||
"""Bridges ``asyncio.DatagramProtocol`` callbacks to
|
||||
``DatagramReader`` / ``DatagramWriter``."""
|
||||
|
||||
def __init__(self, reader: DatagramReader) -> None:
|
||||
self._reader = reader
|
||||
self._writer: DatagramWriter | None = None
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
pass # transport is already wired via DatagramWriter
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
self._reader.feed_datagram(data, addr)
|
||||
|
||||
def datagrams_received(self, data: list[bytes], addr: tuple[str, int]) -> None:
|
||||
self._reader.feed_datagrams(data, addr)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
self._reader.set_exception(exc)
|
||||
|
||||
def connection_lost(self, exc: BaseException | None) -> None:
|
||||
self._reader.connection_lost(exc)
|
||||
if self._writer is not None:
|
||||
self._writer._closed_event.set()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if self._writer is not None:
|
||||
self._writer._pause_writing()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self._writer is not None:
|
||||
self._writer._resume_writing()
|
||||
|
||||
|
||||
async def open_dgram_connection(
    remote_addr: tuple[str, int] | None = None,
    *,
    local_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    sock: socket.socket | None = None,
    gro_segment_size: int = 1280,
) -> tuple[DatagramReader, DatagramWriter]:
    """Open a UDP "connection" and return a StreamReader/StreamWriter-like pair.

    Thin convenience wrapper over :func:`create_udp_endpoint` that wires
    the endpoint's protocol callbacks into a :class:`DatagramReader` and
    wraps its transport in a :class:`DatagramWriter`. Must be called from
    within a running event loop.
    """
    loop = asyncio.get_running_loop()

    reader = DatagramReader()
    protocol = _DatagramBridgeProtocol(reader)

    transport, _ = await create_udp_endpoint(
        loop,
        # The factory returns the already-built bridge; the endpoint is
        # single-use so sharing one protocol instance is fine.
        lambda: protocol,
        local_addr=local_addr,
        remote_addr=remote_addr,
        family=family,
        gro_segment_size=gro_segment_size,
        sock=sock,
    )

    writer = DatagramWriter(transport)
    # Late-bind the writer so connection_lost/pause/resume reach it.
    protocol._writer = writer

    return reader, writer
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from asyncio import CancelledError, events, tasks
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = (
|
||||
"Timeout",
|
||||
"timeout",
|
||||
)
|
||||
|
||||
|
||||
class _State(enum.Enum):
    """Lifecycle states of a :class:`Timeout` context manager.

    The string values double as the human-readable names shown in
    ``Timeout.__repr__``.
    """

    CREATED = "created"
    ENTERED = "active"
    EXPIRING = "expiring"
    EXPIRED = "expired"
    EXITED = "finished"
|
||||
|
||||
|
||||
class Timeout:
    """Asynchronous context manager for cancelling overdue coroutines.

    Backport of :class:`asyncio.Timeout` (Python 3.11+). On expiry the
    enclosing task is cancelled; ``__aexit__`` converts that
    ``CancelledError`` into ``TimeoutError``.

    Use `timeout()` or `timeout_at()` rather than instantiating this class directly.
    """

    def __init__(self, when: float | None) -> None:
        """Schedule a timeout that will trigger at a given loop time.

        - If `when` is `None`, the timeout will never trigger.
        - If `when < loop.time()`, the timeout will trigger on the next
          iteration of the event loop.
        """
        self._state = _State.CREATED

        self._timeout_handler: events.TimerHandle | events.Handle | None = None
        self._task: tasks.Task | None = None  # type: ignore[type-arg]
        self._when = when

    def when(self) -> float | None:
        """Return the current deadline."""
        return self._when

    def reschedule(self, when: float | None) -> None:
        """Reschedule the timeout."""
        # Only an active (entered, not yet expiring) timeout may move.
        if self._state is not _State.ENTERED:
            if self._state is _State.CREATED:
                raise RuntimeError("Timeout has not been entered")
            raise RuntimeError(
                f"Cannot change state of {self._state.value} Timeout",
            )

        self._when = when

        # Drop the previous timer before installing a new one.
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()

        if when is None:
            self._timeout_handler = None
        else:
            loop = events.get_running_loop()
            if when <= loop.time():
                # Already past due: fire on the next loop iteration.
                self._timeout_handler = loop.call_soon(self._on_timeout)
            else:
                self._timeout_handler = loop.call_at(when, self._on_timeout)

    def expired(self) -> bool:
        """Is timeout expired during execution?"""
        return self._state in (_State.EXPIRING, _State.EXPIRED)

    def __repr__(self) -> str:
        info = [""]
        if self._state is _State.ENTERED:
            when = round(self._when, 3) if self._when is not None else None
            info.append(f"when={when}")
        info_str = " ".join(info)
        return f"<Timeout [{self._state.value}]{info_str}>"

    async def __aenter__(self) -> "Timeout":
        if self._state is not _State.CREATED:
            raise RuntimeError("Timeout has already been entered")
        # We need the enclosing task: it is what gets cancelled on expiry.
        task = tasks.current_task()
        if task is None:
            raise RuntimeError("Timeout should be used inside a task")
        self._state = _State.ENTERED
        self._task = task
        self.reschedule(self._when)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        assert self._state in (_State.ENTERED, _State.EXPIRING)
        assert self._task is not None

        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
            self._timeout_handler = None

        if self._state is _State.EXPIRING:
            self._state = _State.EXPIRED

            # Our own expiry cancelled the task: translate the resulting
            # CancelledError into TimeoutError for the caller.
            if exc_type is CancelledError:
                # Since there are no new cancel requests, we're
                # handling this.
                raise TimeoutError from exc_val
        elif self._state is _State.ENTERED:
            self._state = _State.EXITED

        # Any other exception (or none) propagates unchanged.
        return None

    def _on_timeout(self) -> None:
        # Timer callback: cancel the owning task and mark ourselves
        # as expiring so __aexit__ knows the cancellation is ours.
        assert self._state is _State.ENTERED
        assert self._task is not None

        self._task.cancel()
        self._state = _State.EXPIRING
        # drop the reference early
        self._timeout_handler = None
|
||||
|
||||
|
||||
def timeout(delay: float | None) -> Timeout:
    """Timeout async context manager.

    Useful in cases when you want to apply timeout logic around block
    of code or in cases when asyncio.wait_for is not suitable. For example:

    >>> async with timeout(10):  # 10 seconds timeout
    ...     await long_running_task()


    delay - value in seconds or None to disable timeout logic

    long_running_task() is interrupted by raising asyncio.CancelledError,
    the top-most affected timeout() context manager converts CancelledError
    into TimeoutError.

    Must be called from within a running event loop, since the deadline
    is computed from ``loop.time()``.
    """
    loop = events.get_running_loop()
    # None disables the timer entirely; otherwise convert the relative
    # delay into an absolute loop-clock deadline.
    return Timeout(loop.time() + delay if delay is not None else None)
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .protocol import ExtensionFromHTTP
|
||||
from .raw import RawExtensionFromHTTP
|
||||
from .sse import ServerSideEventExtensionFromHTTP
|
||||
|
||||
try:
|
||||
from .ws import WebSocketExtensionFromHTTP, WebSocketExtensionFromMultiplexedHTTP
|
||||
except ImportError:
|
||||
WebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
|
||||
WebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def recursive_subclasses(cls: type[T]) -> list[type[T]]:
    """Return every direct and indirect subclass of *cls*, depth-first.

    Each subclass is listed exactly once, in first-encounter order.
    With multiple inheritance a class can be reachable through several
    bases (diamond hierarchies); deduplicating prevents it from being
    returned — and probed by callers such as ``load_extension`` — more
    than once.
    """
    found: list[type[T]] = []
    seen: set[type[T]] = set()

    def _walk(base: type[T]) -> None:
        # __subclasses__ only yields direct children; recurse for the rest.
        for subclass in base.__subclasses__():
            if subclass in seen:
                continue
            seen.add(subclass)
            found.append(subclass)
            _walk(subclass)

    _walk(cls)
    return found
|
||||
|
||||
|
||||
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[ExtensionFromHTTP]:
    """Locate the extension class handling *scheme*.

    A ``None`` scheme maps to the raw passthrough extension. When
    *implementation* is given (non-empty), only a candidate whose
    ``implementation()`` matches it (case-insensitively) is accepted.

    Raises ImportError when no registered subclass supports the scheme.
    """
    if scheme is None:
        return RawExtensionFromHTTP

    wanted_scheme = scheme.lower()
    wanted_impl = implementation.lower() if implementation else implementation

    for candidate in recursive_subclasses(ExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if wanted_impl is not None and candidate.implementation() != wanted_impl:
            continue
        return candidate

    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
|
||||
|
||||
|
||||
# Public names of this package. The two WebSocket entries are bound to
# None when the optional websocket support fails to import (see above).
__all__ = (
    "ExtensionFromHTTP",
    "RawExtensionFromHTTP",
    "WebSocketExtensionFromHTTP",
    "WebSocketExtensionFromMultiplexedHTTP",
    "ServerSideEventExtensionFromHTTP",
    "load_extension",
)
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
from .raw import AsyncRawExtensionFromHTTP
|
||||
from .sse import AsyncServerSideEventExtensionFromHTTP
|
||||
|
||||
try:
|
||||
from .ws import (
|
||||
AsyncWebSocketExtensionFromHTTP,
|
||||
AsyncWebSocketExtensionFromMultiplexedHTTP,
|
||||
)
|
||||
except ImportError:
|
||||
AsyncWebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
|
||||
AsyncWebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
|
||||
|
||||
from .. import recursive_subclasses
|
||||
|
||||
|
||||
def load_extension(
    scheme: str | None, implementation: str | None = None
) -> type[AsyncExtensionFromHTTP]:
    """Locate the async extension class handling *scheme*.

    A ``None`` scheme maps to the raw passthrough extension. When
    *implementation* is given (non-empty), only a candidate whose
    ``implementation()`` matches it (case-insensitively) is accepted.

    Raises ImportError when no registered subclass supports the scheme.
    """
    if scheme is None:
        return AsyncRawExtensionFromHTTP

    wanted_scheme = scheme.lower()
    wanted_impl = implementation.lower() if implementation else implementation

    for candidate in recursive_subclasses(AsyncExtensionFromHTTP):
        if wanted_scheme not in candidate.supported_schemes():
            continue
        if wanted_impl is not None and candidate.implementation() != wanted_impl:
            continue
        return candidate

    raise ImportError(
        f"Tried to load HTTP extension '{wanted_scheme}' but no available plugin support it."
    )
|
||||
|
||||
|
||||
# Public names of this package. The two WebSocket entries are bound to
# None when the optional websocket support fails to import (see above).
__all__ = (
    "AsyncExtensionFromHTTP",
    "AsyncRawExtensionFromHTTP",
    "AsyncWebSocketExtensionFromHTTP",
    "AsyncWebSocketExtensionFromMultiplexedHTTP",
    "AsyncServerSideEventExtensionFromHTTP",
    "load_extension",
)
|
||||
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from contextlib import asynccontextmanager
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
from ....backend import HttpVersion
|
||||
from ....backend._async._base import AsyncDirectStreamAccess
|
||||
from ....util._async.traffic_police import AsyncTrafficPolice
|
||||
|
||||
from ....exceptions import (
|
||||
BaseSSLError,
|
||||
ProtocolError,
|
||||
ReadTimeoutError,
|
||||
SSLError,
|
||||
MustRedialError,
|
||||
)
|
||||
|
||||
|
||||
class AsyncExtensionFromHTTP(metaclass=ABCMeta):
|
||||
"""Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
|
||||
This will considerably ease downstream integration."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dsa: AsyncDirectStreamAccess | None = None
|
||||
self._response: AsyncHTTPResponse | None = None
|
||||
self._police_officer: AsyncTrafficPolice | None = None # type: ignore[type-arg]
|
||||
|
||||
@asynccontextmanager
|
||||
async def _read_error_catcher(self) -> typing.AsyncGenerator[None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
clean_exit = True
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
# FIXME: Is there a better way to differentiate between SSLErrors?
|
||||
if "read operation timed out" not in str(e):
|
||||
# SSL errors related to framing/MAC get wrapped and reraised here
|
||||
raise SSLError(e) from e
|
||||
clean_exit = True # ws algorithms based on timeouts can expect this without being harmful!
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except (OSError, MustRedialError) as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
await self.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _write_error_catcher(self) -> typing.AsyncGenerator[None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
raise SSLError(e) from e
|
||||
|
||||
except OSError as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def urlopen_kwargs(self) -> dict[str, typing.Any]:
|
||||
return {}
|
||||
|
||||
    async def start(self, response: AsyncHTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol."""
        # Direct stream access (``_dsa``) must be exposed by the response body
        # layer; without it the extension cannot do raw I/O.
        # NOTE(review): the sync counterpart raises RuntimeError here with a
        # different message — confirm OSError is intended for the async path.
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise OSError("The HTTP extension is closed or uninitialized")

        # Keep the references needed for later reads/writes and for
        # connection-pool bookkeeping.
        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response
|
||||
|
||||
    @property
    def closed(self) -> bool:
        # The extension is considered closed once direct stream access is gone.
        return self._dsa is None
|
||||
|
||||
    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        # Abstract: concrete extensions list the HTTP versions they can ride on.
        raise NotImplementedError
|
||||
|
||||
    @staticmethod
    def implementation() -> str:
        # Abstract: short name of the backing implementation
        # (e.g. "raw", "native", "wsproto" in subclasses).
        raise NotImplementedError
|
||||
|
||||
    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        # Abstract: e.g. {"ws", "wss"} for WebSocket, {"sse", "psse"} for SSE.
        raise NotImplementedError
|
||||
|
||||
    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError
|
||||
|
||||
    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError
|
||||
|
||||
    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        # Abstract: implementations release the stream and the pooled connection.
        raise NotImplementedError
|
||||
|
||||
    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError
|
||||
|
||||
    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError
|
||||
|
||||
    async def on_payload(
        self, callback: typing.Callable[[str | bytes | None], typing.Awaitable[None]]
    ) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncRawExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Raw passthrough works over every HTTP version.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        # Raw mode imposes no extra upgrade headers.
        return {}

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            await self._dsa.close()
            self._dsa = None
        if self._response is not None:
            await self._response.close()
            self._response = None
        self._police_officer = None

    @staticmethod
    def implementation() -> str:
        return "raw"

    @staticmethod
    def supported_schemes() -> set[str]:
        # No dedicated URL scheme: raw mode is selected explicitly, not by scheme.
        return set()

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Identity mapping: the caller's scheme is already an http(s) scheme.
        return scheme

    async def next_payload(self) -> bytes | None:
        """Read the next chunk straight off the switched stream (None on EOF)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        # borrow() serializes access to the pooled connection while we read.
        async with self._police_officer.borrow(self._response):
            async with self._read_error_catcher():
                data, eot, _ = await self._dsa.recv_extended(None)
            return data

    async def send_payload(self, buf: str | bytes) -> None:
        """Write a buffer verbatim to the switched stream (str is UTF-8 encoded)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        if isinstance(buf, str):
            buf = buf.encode()

        async with self._police_officer.borrow(self._response):
            async with self._write_error_catcher():
                await self._dsa.sendall(buf)
|
||||
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from ..sse import ServerSentEvent
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncServerSideEventExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Async Server-Sent Events consumer layered over a streamed HTTP response."""

    def __init__(self) -> None:
        super().__init__()

        # Last non-empty event id seen; re-applied to events that omit an id.
        self._last_event_id: str | None = None
        # Partial event text carried until the blank-line terminator arrives.
        self._buffer: str = ""
        # Chunk generator over the decoded response body, set in start().
        self._stream: typing.AsyncGenerator[bytes, None] | None = None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        return "native"

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # Events must be read incrementally; never preload the body.
        return {"preload_content": False}

    async def close(self) -> None:
        if self._stream is not None and self._response is not None:
            await self._stream.aclose()
            # Abort the underlying reader (when supported) so the connection
            # is not left mid-stream.
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                async with self._police_officer.borrow(self._response):
                    await self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None

    @property
    def closed(self) -> bool:
        return self._stream is None

    async def start(self, response: AsyncHTTPResponse) -> None:
        await super().start(response)

        # -1 => unbounded streamed read; decode_content handles content-encodings.
        self._stream = response.stream(-1, decode_content=True)

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        return {"accept": "text/event-stream", "cache-control": "no-store"}

    @typing.overload
    async def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...

    @typing.overload
    async def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...

    async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote."""
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        try:
            raw_payload: str = (await self._stream.__anext__()).decode("utf-8")
        except StopAsyncIteration:
            # Remote ended the stream: release it and signal EOF with None.
            await self._stream.aclose()
            self._stream = None
            return None

        # Prepend any partial event left over from the previous chunk.
        if self._buffer:
            raw_payload = self._buffer + raw_payload
            self._buffer = ""

        kwargs: dict[str, typing.Any] = {}
        eot = False  # set once the event-terminating blank line is found

        for line in raw_payload.splitlines():
            if not line:
                eot = True
                break
            key, _, value = line.partition(":")
            if key not in {"event", "data", "retry", "id"}:
                continue
            # A single space right after ':' is not part of the value (SSE spec).
            if value.startswith(" "):
                value = value[1:]
            if key == "id":
                # Ids containing NUL are ignored per the SSE processing model.
                if "\u0000" in value:
                    continue
            if key == "retry":
                try:
                    value = int(value)  # type: ignore[assignment]
                except (ValueError, TypeError):
                    continue
            kwargs[key] = value

        if eot is False:
            # Incomplete event: stash it and recurse to read more from the stream.
            self._buffer = raw_payload
            return await self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]

        if "id" not in kwargs and self._last_event_id is not None:
            kwargs["id"] = self._last_event_id

        event = ServerSentEvent(**kwargs)

        if event.id:
            self._last_event_id = event.id

        if raw is True:
            return raw_payload

        return event

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"sse", "psse"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # sse => TLS (https); psse ("plain" sse) => http.
        return {"sse": "https", "psse": "http"}[scheme]
|
||||
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
|
||||
from wsproto import ConnectionType, WSConnection
|
||||
from wsproto.events import (
|
||||
AcceptConnection,
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Ping,
|
||||
Pong,
|
||||
Request,
|
||||
TextMessage,
|
||||
)
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
from wsproto.utilities import ProtocolError as WebSocketProtocolError
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from ....exceptions import ProtocolError
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncWebSocketExtensionFromHTTP(AsyncExtensionFromHTTP):
    """Async WebSocket extension driven by the wsproto sans-io state machine."""

    def __init__(self) -> None:
        super().__init__()
        # Client-side wsproto state machine; fed raw bytes, yields events.
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Cached upgrade-request headers produced once by headers().
        self._request_headers: dict[str, str] | None = None
        # True once the remote sent CloseConnection; skip sending our own close.
        self._remote_shutdown: bool = False

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # The classic Upgrade handshake only exists in HTTP/1.1.
        return {HttpVersion.h11}

    @staticmethod
    def implementation() -> str:
        return "wsproto"

    async def start(self, response: AsyncHTTPResponse) -> None:
        await super().start(response)

        # wsproto expects to parse the server's 101 response itself; rebuild a
        # minimal HTTP/1.1 response from the headers we already received.
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"

        fake_http_response += b"Sec-Websocket-Accept: "

        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")

        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )

        fake_http_response += accept_token.encode() + b"\r\n"

        # Forward negotiated extensions (e.g. permessage-deflate) so wsproto
        # enables them on its side too.
        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )

        fake_http_response += b"\r\n"

        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!

        event = next(self._protocol.events())

        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers

        try:
            # Let wsproto build a complete upgrade request; host/target are
            # placeholders — only the generated headers are extracted below.
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never occur!

        # Strip the request line and the trailing CRLFCRLF, keep header lines.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}

        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v

        # Over HTTP/2 or HTTP/3, RFC 8441 replaces Upgrade/Connection with
        # the extended-CONNECT pseudo-headers.
        if http_version != HttpVersion.h11:
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"

        self._request_headers = request_headers

        return request_headers

    async def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                async with self._police_officer.borrow(self._response):
                    # Only initiate a close frame if the remote did not already.
                    if self._remote_shutdown is False:
                        try:
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            # State machine refuses (e.g. already closing): best effort.
                            pass
                        else:
                            async with self._write_error_catcher():
                                await self._dsa.sendall(data_to_send)
                    await self._dsa.close()
                    self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                # Let the pool coordinator drop its tracking of this response.
                self._police_officer.forget(self._response)
            else:
                await self._response.close()
            self._response = None

        self._police_officer = None

    async def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        async with self._police_officer.borrow(self._response):
            # First drain events already buffered inside wsproto.
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    await self.close()
                    return None
                elif isinstance(event, Ping):
                    # Answer pings transparently; keep waiting for a message.
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        await self.close()
                        raise ProtocolError from e

                    async with self._write_error_catcher():
                        await self._dsa.sendall(data_to_send)

            # Nothing buffered: read from the socket until a full message,
            # a close, or an error surfaces.
            while True:
                async with self._read_error_catcher():
                    data, eot, _ = await self._dsa.recv_extended(None)

                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    raise ProtocolError from e

                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        await self.close()
                        return None
                    elif isinstance(event, Ping):
                        data_to_send = self._protocol.send(event.response())
                        async with self._write_error_catcher():
                            await self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        # Unsolicited/awaited pong carries no payload for callers.
                        continue

    async def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        async with self._police_officer.borrow(self._response):
            try:
                # str => text frame, bytes => binary frame.
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                raise ProtocolError from e

            async with self._write_error_catcher():
                await self._dsa.sendall(data_to_send)

    async def ping(self) -> None:
        """Send a WebSocket Ping frame to the remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        async with self._police_officer.borrow(self._response):
            try:
                data_to_send: bytes = self._protocol.send(Ping())
            except WebSocketProtocolError as e:
                raise ProtocolError from e

            async with self._write_error_catcher():
                await self._dsa.sendall(data_to_send)

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        return {"ws": "http", "wss": "https"}[scheme]
|
||||
|
||||
|
||||
class AsyncWebSocketExtensionFromMultiplexedHTTP(AsyncWebSocketExtensionFromHTTP):
    """
    Plugin that support doing WebSocket over HTTP 2 and 3.
    This implement RFC8441. Beware that this isn't actually supported by much server around internet.
    """

    @staticmethod
    def implementation() -> str:
        return "rfc8441"

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Unlike the parent (h11 only), extended CONNECT allows h2/h3 too.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from contextlib import contextmanager
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...backend import HttpVersion
|
||||
from ...backend._base import DirectStreamAccess
|
||||
from ...response import HTTPResponse
|
||||
from ...util.traffic_police import TrafficPolice
|
||||
|
||||
from ...exceptions import (
|
||||
BaseSSLError,
|
||||
ProtocolError,
|
||||
ReadTimeoutError,
|
||||
SSLError,
|
||||
MustRedialError,
|
||||
)
|
||||
|
||||
|
||||
class ExtensionFromHTTP(metaclass=ABCMeta):
    """Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
    This will considerably ease downstream integration."""

    def __init__(self) -> None:
        # Direct stream access to the raw socket; available only after start().
        self._dsa: DirectStreamAccess | None = None
        # The switched-protocol HTTP response this extension is bound to.
        self._response: HTTPResponse | None = None
        # Pool coordinator used to serialize access to the pooled connection.
        self._police_officer: TrafficPolice | None = None  # type: ignore[type-arg]

    @contextmanager
    def _read_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.

        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False

        try:
            try:
                yield

            except SocketTimeout as e:
                # A read timeout is recoverable: keep the connection alive.
                clean_exit = True
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]

            except BaseSSLError as e:
                # FIXME: Is there a better way to differentiate between SSLErrors?
                if "read operation timed out" not in str(e):
                    # SSL errors related to framing/MAC get wrapped and reraised here
                    raise SSLError(e) from e
                clean_exit = True  # ws algorithms based on timeouts can expect this without being harmful!
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]

            except (OSError, MustRedialError) as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e

            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()

    @contextmanager
    def _write_error_catcher(self) -> typing.Generator[None, None, None]:
        """
        Catch low-level python exceptions, instead re-raising urllib3
        variants, so that low-level exceptions are not leaked in the
        high-level api.

        On unrecoverable issues, release the connection back to the pool.
        """
        clean_exit = False

        try:
            try:
                yield

            except SocketTimeout as e:
                pool = (
                    self._response._pool
                    if self._response and hasattr(self._response, "_pool")
                    else None
                )
                raise ReadTimeoutError(pool, None, "Read timed out.") from e  # type: ignore[arg-type]

            except BaseSSLError as e:
                raise SSLError(e) from e

            except OSError as e:
                # This includes IncompleteRead.
                raise ProtocolError(f"Connection broken: {e!r}", e) from e

            # If no exception is thrown, we should avoid cleaning up
            # unnecessarily.
            clean_exit = True
        finally:
            # If we didn't terminate cleanly, we need to throw away our
            # connection.
            if not clean_exit:
                # The response may not be closed but we're not going to use it
                # anymore so close it now to ensure that the connection is
                # released back to the pool.
                if self._response:
                    self.close()

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        """Return prerequisites. Must be passed as additional parameters to urlopen."""
        return {}

    def start(self, response: HTTPResponse) -> None:
        """The HTTP server gave us the go-to start negotiating another protocol."""
        # Without direct stream access the extension cannot perform raw I/O.
        if response._fp is None or not hasattr(response._fp, "_dsa"):
            raise RuntimeError(
                "Attempt to start an HTTP extension without direct I/O access to the stream"
            )

        self._dsa = response._fp._dsa
        self._police_officer = response._police_officer
        self._response = response

    @property
    def closed(self) -> bool:
        # Closed once direct stream access has been relinquished.
        return self._dsa is None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        """Hint about supported parent SVN for this extension."""
        raise NotImplementedError

    @staticmethod
    def implementation() -> str:
        # Abstract: short name of the backing implementation (e.g. "raw", "native").
        raise NotImplementedError

    @staticmethod
    def supported_schemes() -> set[str]:
        """Recognized schemes for the extension."""
        raise NotImplementedError

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        """Convert the extension scheme to a known http scheme (either http or https)"""
        raise NotImplementedError

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        raise NotImplementedError

    def close(self) -> None:
        """End/Notify close for sub protocol."""
        raise NotImplementedError

    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote. This call does read from the socket.
        If the method return None, it means that the remote closed the (extension) pipeline.
        """
        raise NotImplementedError

    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError

    def on_payload(self, callback: typing.Callable[[str | bytes | None], None]) -> None:
        """Set up a callback that will be invoked automatically once a payload is received.
        Meaning that you stop calling manually next_payload()."""
        raise NotImplementedError
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class RawExtensionFromHTTP(ExtensionFromHTTP):
    """Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # Raw passthrough works over every HTTP version.
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        # Raw mode imposes no extra upgrade headers.
        return {}

    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            # NOTE(review): unlike the async variant, the stream close here is
            # wrapped in the write error catcher — confirm the asymmetry is intended.
            with self._write_error_catcher():
                self._dsa.close()
            self._dsa = None
        if self._response is not None:
            self._response.close()
            self._response = None
        self._police_officer = None

    @staticmethod
    def implementation() -> str:
        return "raw"

    @staticmethod
    def supported_schemes() -> set[str]:
        # No dedicated URL scheme: raw mode is selected explicitly, not by scheme.
        return set()

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # Identity mapping: the caller's scheme is already an http(s) scheme.
        return scheme

    def next_payload(self) -> bytes | None:
        """Read the next chunk straight off the switched stream (None on EOF)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        # borrow() serializes access to the pooled connection while we read.
        with self._police_officer.borrow(self._response):
            with self._read_error_catcher():
                data, eot, _ = self._dsa.recv_extended(None)
            return data

    def send_payload(self, buf: str | bytes) -> None:
        """Write a buffer verbatim to the switched stream (str is UTF-8 encoded)."""
        if self._police_officer is None or self._dsa is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        if isinstance(buf, str):
            buf = buf.encode()

        with self._police_officer.borrow(self._response):
            with self._write_error_catcher():
                self._dsa.sendall(buf)
|
||||
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from threading import RLock
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...response import HTTPResponse
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class ServerSentEvent:
    """Value object for one Server-Sent Event (event/data/id/retry fields)."""

    def __init__(
        self,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        # Normalize missing/empty fields to their SSE defaults:
        # event type defaults to "message"; data and id default to "".
        self._event = event or "message"
        self._data = "" if data is None else data
        self._id = "" if id is None else id
        self._retry = retry

    @property
    def event(self) -> str:
        """Event type (defaults to "message")."""
        return self._event

    @property
    def data(self) -> str:
        """Raw data payload as text ("" when absent)."""
        return self._data

    @property
    def id(self) -> str:
        """Event id ("" when absent)."""
        return self._id

    @property
    def retry(self) -> int | None:
        """Reconnection delay advised by the server, if any."""
        return self._retry

    def json(self) -> typing.Any:
        """Parse the data payload as JSON and return the resulting object."""
        return json.loads(self._data)

    def __repr__(self) -> str:
        # Only non-default fields appear, keeping the repr compact.
        pieces = [f"event={self._event!r}"]
        if self._data:
            pieces.append(f"data={self._data!r}")
        if self._id:
            pieces.append(f"id={self._id!r}")
        if self._retry is not None:
            pieces.append(f"retry={self._retry!r}")
        return f"ServerSentEvent({', '.join(pieces)})"
|
||||
|
||||
|
||||
class ServerSideEventExtensionFromHTTP(ExtensionFromHTTP):
    """Server-Sent Events consumer layered over a streamed HTTP response."""

    def __init__(self) -> None:
        super().__init__()

        # Last non-empty event id seen; re-applied to events that omit an id.
        self._last_event_id: str | None = None
        # Partial event text carried until the blank-line terminator arrives.
        self._buffer: str = ""
        # Serializes next_payload() across threads; reentrant because
        # next_payload() recurses on incomplete events.
        self._lock = RLock()
        # Chunk generator over the decoded response body, set in start().
        self._stream: typing.Generator[bytes, None, None] | None = None

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        return "native"

    @property
    def urlopen_kwargs(self) -> dict[str, typing.Any]:
        # Events must be read incrementally; never preload the body.
        return {"preload_content": False}

    @property
    def closed(self) -> bool:
        return self._stream is None

    def close(self) -> None:
        if self._stream is not None and self._response is not None:
            self._stream.close()
            # Abort the underlying reader (when supported) so the connection
            # is not left mid-stream.
            if (
                self._response._fp is not None
                and self._police_officer is not None
                and hasattr(self._response._fp, "abort")
            ):
                with self._police_officer.borrow(self._response):
                    self._response._fp.abort()
            self._stream = None
            self._response = None
            self._police_officer = None

    def start(self, response: HTTPResponse) -> None:
        super().start(response)

        # -1 => unbounded streamed read; decode_content handles content-encodings.
        self._stream = response.stream(-1, decode_content=True)

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        return {"accept": "text/event-stream", "cache-control": "no-store"}

    @typing.overload
    def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...

    @typing.overload
    def next_payload(
        self, *, raw: typing.Literal[False] = False
    ) -> ServerSentEvent | None: ...

    def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
        """Unpack the next received message/payload from remote."""
        if self._response is None or self._stream is None:
            raise OSError("The HTTP extension is closed or uninitialized")
        with self._lock:
            try:
                raw_payload: str = next(self._stream).decode("utf-8")
            except StopIteration:
                # Stream exhausted by remote: signal EOF with None.
                self._stream = None
                return None

            # Prepend any partial event left over from the previous chunk.
            if self._buffer:
                raw_payload = self._buffer + raw_payload
                self._buffer = ""

            kwargs: dict[str, typing.Any] = {}
            eot = False  # set once the event-terminating blank line is found

            for line in raw_payload.splitlines():
                if not line:
                    eot = True
                    break
                key, _, value = line.partition(":")
                if key not in {"event", "data", "retry", "id"}:
                    continue
                # A single space right after ':' is not part of the value (SSE spec).
                if value.startswith(" "):
                    value = value[1:]
                if key == "id":
                    # Ids containing NUL are ignored per the SSE processing model.
                    if "\u0000" in value:
                        continue
                if key == "retry":
                    try:
                        value = int(value)  # type: ignore[assignment]
                    except (ValueError, TypeError):
                        continue
                kwargs[key] = value

            if eot is False:
                # Incomplete event: stash it and recurse to read more.
                self._buffer = raw_payload
                return self.next_payload(raw=raw)  # type: ignore[call-overload,no-any-return]

            if "id" not in kwargs and self._last_event_id is not None:
                kwargs["id"] = self._last_event_id

            event = ServerSentEvent(**kwargs)

            if event.id:
                self._last_event_id = event.id

            if raw is True:
                return raw_payload

            return event

    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        raise NotImplementedError("SSE is only one-way. Sending is forbidden.")

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"sse", "psse"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # sse => TLS (https); psse ("plain" sse) => http.
        return {"sse": "https", "psse": "http"}[scheme]
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...response import HTTPResponse
|
||||
|
||||
from wsproto import ConnectionType, WSConnection
|
||||
from wsproto.events import (
|
||||
AcceptConnection,
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Ping,
|
||||
Pong,
|
||||
Request,
|
||||
TextMessage,
|
||||
)
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
from wsproto.utilities import ProtocolError as WebSocketProtocolError
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from ...exceptions import ProtocolError
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class WebSocketExtensionFromHTTP(ExtensionFromHTTP):
    """WebSocket-over-HTTP/1.1 extension backed by the ``wsproto`` state machine.

    The HTTP layer performs the actual 101 upgrade; this class replays the
    upgrade outcome into a client-side ``WSConnection`` and then shuttles
    WebSocket frames between the socket and the wsproto protocol object.
    """

    def __init__(self) -> None:
        super().__init__()
        # Client-side sans-io WebSocket state machine.
        self._protocol = WSConnection(ConnectionType.CLIENT)
        # Cached handshake headers, computed lazily once by headers().
        self._request_headers: dict[str, str] | None = None
        # Set when the remote peer initiated the closing handshake first.
        self._remote_shutdown: bool = False

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        # The plain RFC 6455 upgrade only works over HTTP/1.1.
        return {HttpVersion.h11}

    @staticmethod
    def implementation() -> str:
        return "wsproto"

    def start(self, response: HTTPResponse) -> None:
        super().start(response)

        # wsproto expects to parse the raw 101 response itself, but the HTTP
        # layer already consumed it. Rebuild an equivalent byte payload from
        # the parsed response headers and feed it back to the state machine.
        fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"

        fake_http_response += b"Sec-Websocket-Accept: "

        accept_token: str | None = response.headers.get("Sec-Websocket-Accept")

        if accept_token is None:
            raise ProtocolError(
                "The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
            )

        fake_http_response += accept_token.encode() + b"\r\n"

        if "sec-websocket-extensions" in response.headers:
            fake_http_response += (
                b"Sec-Websocket-Extensions: "
                + response.headers.get("sec-websocket-extensions").encode()  # type: ignore[union-attr]
                + b"\r\n"
            )

        fake_http_response += b"\r\n"

        try:
            self._protocol.receive_data(fake_http_response)
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen

        event = next(self._protocol.events())

        if not isinstance(event, AcceptConnection):
            raise RuntimeError(
                "The WebSocket state-machine did not pass the handshake phase when expected."
            )

    def headers(self, http_version: HttpVersion) -> dict[str, str]:
        """Specific HTTP headers required (request) before the 101 status response."""
        if self._request_headers is not None:
            return self._request_headers

        # Let wsproto generate a canonical handshake request, then harvest its
        # headers rather than hand-crafting Sec-WebSocket-Key and friends.
        # Host/target are placeholders; only the headers are used.
        try:
            raw_data_to_socket = self._protocol.send(
                Request(
                    host="example.com", target="/", extensions=(PerMessageDeflate(),)
                )
            )
        except WebSocketProtocolError as e:
            raise ProtocolError from e  # Defensive: should never happen

        # NOTE(review): slice [2:-2] drops the request line plus one more line
        # and the trailing CRLF pair — confirm against wsproto's exact framing.
        raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
        request_headers: dict[str, str] = {}

        for raw_header in raw_headers:
            k, v = raw_header.decode().split(": ")
            request_headers[k.lower()] = v

        if http_version != HttpVersion.h11:
            # RFC 8441 (extended CONNECT): hop-by-hop upgrade headers are
            # forbidden in h2/h3; pseudo-headers carry the intent instead.
            del request_headers["upgrade"]
            del request_headers["connection"]
            request_headers[":protocol"] = "websocket"
            request_headers[":method"] = "CONNECT"

        self._request_headers = request_headers

        return request_headers

    def close(self) -> None:
        """End/Notify close for sub protocol."""
        if self._dsa is not None:
            if self._police_officer is not None:
                with self._police_officer.borrow(self._response):
                    if self._remote_shutdown is False:
                        # We initiate the closing handshake; if the protocol
                        # refuses (e.g. already closed) just tear down anyway.
                        try:
                            data_to_send: bytes = self._protocol.send(
                                CloseConnection(0)
                            )
                        except WebSocketProtocolError:
                            pass
                        else:
                            with self._write_error_catcher():
                                self._dsa.sendall(data_to_send)
                    self._dsa.close()
                    self._dsa = None
            else:
                self._dsa = None
        if self._response is not None:
            if self._police_officer is not None:
                self._police_officer.forget(self._response)
            else:
                self._response.close()
            self._response = None

        self._police_officer = None

    def next_payload(self) -> str | bytes | None:
        """Unpack the next received message/payload from remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        with self._police_officer.borrow(self._response):
            # we may have pending event to unpack!
            for event in self._protocol.events():
                if isinstance(event, TextMessage):
                    return event.data
                elif isinstance(event, BytesMessage):
                    return event.data
                elif isinstance(event, CloseConnection):
                    self._remote_shutdown = True
                    self.close()
                    return None
                elif isinstance(event, Ping):
                    # Answer keep-alive probes transparently.
                    try:
                        data_to_send: bytes = self._protocol.send(event.response())
                    except WebSocketProtocolError as e:
                        self.close()
                        raise ProtocolError from e

                    with self._write_error_catcher():
                        self._dsa.sendall(data_to_send)

            # No buffered event: block on the socket until a full message,
            # a close, or a ping materializes.
            while True:
                with self._read_error_catcher():
                    data, eot, _ = self._dsa.recv_extended(None)

                try:
                    self._protocol.receive_data(data)
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e

                for event in self._protocol.events():
                    if isinstance(event, TextMessage):
                        return event.data
                    elif isinstance(event, BytesMessage):
                        return event.data
                    elif isinstance(event, CloseConnection):
                        self._remote_shutdown = True
                        self.close()
                        return None
                    elif isinstance(event, Ping):
                        try:
                            data_to_send = self._protocol.send(event.response())
                        except WebSocketProtocolError as e:
                            self.close()
                            raise ProtocolError from e
                        with self._write_error_catcher():
                            self._dsa.sendall(data_to_send)
                    elif isinstance(event, Pong):
                        # Unsolicited/solicited pongs carry no payload for us.
                        continue

    def send_payload(self, buf: str | bytes) -> None:
        """Dispatch a buffer to remote."""
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        with self._police_officer.borrow(self._response):
            # str -> text frame, bytes -> binary frame.
            try:
                if isinstance(buf, str):
                    data_to_send: bytes = self._protocol.send(TextMessage(buf))
                else:
                    data_to_send = self._protocol.send(BytesMessage(buf))
            except WebSocketProtocolError as e:
                self.close()
                raise ProtocolError from e

            with self._write_error_catcher():
                self._dsa.sendall(data_to_send)

    def ping(self) -> None:
        # Send an unsolicited Ping; skipped when the remote already closed.
        if self._dsa is None or self._response is None or self._police_officer is None:
            raise OSError("The HTTP extension is closed or uninitialized")

        with self._police_officer.borrow(self._response):
            if self._remote_shutdown is False:
                try:
                    data_to_send: bytes = self._protocol.send(Ping())
                except WebSocketProtocolError as e:
                    self.close()
                    raise ProtocolError from e

                with self._write_error_catcher():
                    self._dsa.sendall(data_to_send)

    @staticmethod
    def supported_schemes() -> set[str]:
        return {"ws", "wss"}

    @staticmethod
    def scheme_to_http_scheme(scheme: str) -> str:
        # "ws" upgrades over plain http, "wss" over https.
        return {"ws": "http", "wss": "https"}[scheme]
|
||||
|
||||
|
||||
class WebSocketExtensionFromMultiplexedHTTP(WebSocketExtensionFromHTTP):
    """WebSocket bootstrapped through HTTP/2 or HTTP/3 (extended CONNECT).

    Implements RFC 8441 (and its HTTP/3 counterpart, RFC 9220). Be aware
    that real-world server support for this mechanism is still uncommon.
    """

    @staticmethod
    def supported_svn() -> set[HttpVersion]:
        return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}

    @staticmethod
    def implementation() -> str:
        return "rfc8441"  # also known as rfc9220 (http3)
|
||||
373
.venv/lib/python3.9/site-packages/urllib3_future/exceptions.py
Normal file
373
.venv/lib/python3.9/site-packages/urllib3_future/exceptions.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""Exception and warning hierarchy used throughout urllib3-future."""
from __future__ import annotations

import socket
import typing
from email.errors import MessageDefect

if typing.TYPE_CHECKING:
    from ._async.connection import AsyncHTTPConnection
    from ._async.connectionpool import AsyncConnectionPool
    from ._async.response import AsyncHTTPResponse
    from ._typing import _TYPE_REDUCE_RESULT
    from .backend import ResponsePromise
    from .connection import HTTPConnection
    from .connectionpool import ConnectionPool
    from .response import HTTPResponse
    from .util.retry import Retry

# Base Exceptions
try:  # Compiled with SSL?
    import ssl

    BaseSSLError = ssl.SSLError
except (ImportError, AttributeError):
    ssl = None  # type: ignore[assignment]

    # Fallback stand-in so `except BaseSSLError` stays valid without ssl.
    class BaseSSLError(BaseException):  # type: ignore[no-redef]
        pass


class HTTPError(Exception):
    """Base exception used by this module."""


class HTTPWarning(Warning):
    """Base warning used by this module."""


class PoolError(HTTPError):
    """Base exception for errors caused within a pool."""

    def __init__(
        self, pool: ConnectionPool | AsyncConnectionPool, message: str
    ) -> None:
        self.pool = pool
        super().__init__(f"{pool}: {message}")

    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. The pool itself is not picklable.
        return self.__class__, (None, None)


class RequestError(PoolError):
    """Base exception for PoolErrors that have associated URLs."""

    def __init__(
        self, pool: ConnectionPool | AsyncConnectionPool, url: str, message: str
    ) -> None:
        self.url = url
        super().__init__(pool, message)

    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes. Keep the URL; drop the unpicklable pool.
        return self.__class__, (None, self.url, None)


class SSLError(HTTPError):
    """Raised when SSL certificate fails in an HTTPS connection."""


class ProxyError(HTTPError):
    """Raised when the connection to a proxy fails."""

    # The original error is also available as __cause__.
    original_error: Exception

    def __init__(self, message: str, error: Exception) -> None:
        super().__init__(message, error)
        self.original_error = error


class DecodeError(HTTPError):
    """Raised when automatic decoding based on Content-Type fails."""


class ProtocolError(HTTPError):
    """Raised when something unexpected happens mid-request/response."""


#: Renamed to ProtocolError but aliased for backwards compatibility.
ConnectionError = ProtocolError


# Leaf Exceptions


class MaxRetryError(RequestError):
    """Raised when the maximum number of retries is exceeded.

    :param pool: The connection pool
    :type pool: :class:`~urllib3.connectionpool.HTTPConnectionPool`
    :param str url: The requested Url
    :param reason: The underlying error
    :type reason: :class:`Exception`

    """

    def __init__(
        self,
        pool: ConnectionPool | AsyncConnectionPool,
        url: str,
        reason: Exception | None = None,
    ) -> None:
        self.reason = reason

        message = f"Max retries exceeded with url: {url} (Caused by {reason!r})"

        super().__init__(pool, url, message)


class HostChangedError(RequestError):
    """Raised when an existing pool gets a request for a foreign host."""

    def __init__(
        self,
        pool: ConnectionPool | AsyncConnectionPool,
        url: str,
        retries: Retry | int = 3,
    ) -> None:
        message = f"Tried to open a foreign host with url: {url}"
        super().__init__(pool, url, message)
        self.retries = retries


class TimeoutStateError(HTTPError):
    """Raised when passing an invalid state to a timeout"""


class TimeoutError(HTTPError):
    """Raised when a socket timeout error occurs.

    Catching this error will catch both :exc:`ReadTimeoutErrors
    <ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
    """


class ReadTimeoutError(TimeoutError, RequestError):
    """Raised when a socket timeout occurs while receiving data from a server"""


# This timeout error does not have a URL attached and needs to inherit from the
# base HTTPError
class ConnectTimeoutError(TimeoutError):
    """Raised when a socket timeout occurs while connecting to a server"""


class NewConnectionError(ConnectTimeoutError, HTTPError):
    """Raised when we fail to establish a new connection. Usually ECONNREFUSED."""

    def __init__(
        self,
        conn: HTTPConnection
        | AsyncHTTPConnection
        | ConnectionPool
        | AsyncConnectionPool,
        message: str,
    ) -> None:
        self.conn = conn
        super().__init__(f"{conn}: {message}")

    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes.
        return self.__class__, (None, None)


class NameResolutionError(NewConnectionError):
    """Raised when host name resolution fails."""

    def __init__(
        self,
        host: str,
        conn: HTTPConnection
        | AsyncHTTPConnection
        | ConnectionPool
        | AsyncConnectionPool,
        reason: socket.gaierror,
    ):
        message = f"Failed to resolve '{host}' ({reason})"
        super().__init__(conn, message)

    def __reduce__(self) -> _TYPE_REDUCE_RESULT:
        # For pickling purposes.
        return self.__class__, (None, None, None)


class EmptyPoolError(PoolError):
    """Raised when a pool runs out of connections and no more are allowed."""


class FullPoolError(PoolError):
    """Raised when we try to add a connection to a full pool in blocking mode."""


class ClosedPoolError(PoolError):
    """Raised when a request enters a pool after the pool has been closed."""


class LocationValueError(ValueError, HTTPError):
    """Raised when there is something wrong with a given URL input."""


class LocationParseError(LocationValueError):
    """Raised when get_host or similar fails to parse the URL input."""

    def __init__(self, location: str) -> None:
        message = f"Failed to parse: {location}"
        super().__init__(message)

        self.location = location


class URLSchemeUnknown(LocationValueError):
    """Raised when a URL input has an unsupported scheme."""

    def __init__(self, scheme: str):
        message = f"Not supported URL scheme {scheme}"
        super().__init__(message)

        self.scheme = scheme


class ResponseError(HTTPError):
    """Used as a container for an error reason supplied in a MaxRetryError."""

    GENERIC_ERROR = "too many error responses"
    SPECIFIC_ERROR = "too many {status_code} error responses"


class SecurityWarning(HTTPWarning):
    """Warned when performing security reducing actions"""


class InsecureRequestWarning(SecurityWarning):
    """Warned when making an unverified HTTPS request."""


class NotOpenSSLWarning(SecurityWarning):
    """Warned when using unsupported SSL library"""


class SystemTimeWarning(SecurityWarning):
    """Warned when system time is suspected to be wrong"""


class InsecurePlatformWarning(SecurityWarning):
    """Warned when certain TLS/SSL configuration is not available on a platform."""


class DependencyWarning(HTTPWarning):
    """
    Warned when an attempt is made to import a module with missing optional
    dependencies.
    """


class ResponseNotChunked(ProtocolError, ValueError):
    """Response needs to be chunked in order to read it as chunks."""


class BodyNotHttplibCompatible(HTTPError):
    """
    Body should be :class:`http.client.HTTPResponse` like
    (have an fp attribute which returns raw chunks) for read_chunked().
    """


class IncompleteRead(ProtocolError):
    """
    Response length doesn't match expected Content-Length

    Subclass of :class:`http.client.IncompleteRead` to allow int value
    for ``partial`` to avoid creating large objects on streamed reads.
    """

    def __init__(self, partial: int, expected: int | None = None) -> None:
        self.partial = partial
        self.expected = expected

    def __repr__(self) -> str:
        if self.expected is not None:
            return f"IncompleteRead({self.partial} bytes read, {self.expected} more expected)"
        return f"IncompleteRead({self.partial} bytes read)"

    # Keep the terse default Exception str() rather than repr-like output.
    __str__ = object.__str__


class InvalidChunkLength(ProtocolError):
    """Invalid chunk length in a chunked response."""

    def __init__(
        self, response: HTTPResponse | AsyncHTTPResponse, length: bytes
    ) -> None:
        self.partial: int = response.tell()
        self.expected: int | None = response.length_remaining
        self.response = response
        self.length = length

    def __repr__(self) -> str:
        return "InvalidChunkLength(got length %r, %i bytes read)" % (
            self.length,
            self.partial,
        )


class InvalidHeader(HTTPError):
    """The header provided was somehow invalid."""


class ProxySchemeUnknown(AssertionError, URLSchemeUnknown):
    """ProxyManager does not support the supplied scheme"""

    # TODO(t-8ch): Stop inheriting from AssertionError in v2.0.

    def __init__(self, scheme: str | None) -> None:
        # 'localhost' is here because our URL parser parses
        # localhost:8080 -> scheme=localhost, remove if we fix this.
        if scheme == "localhost":
            scheme = None
        if scheme is None:
            message = "Proxy URL had no scheme, should start with http:// or https://"
        else:
            message = f"Proxy URL had unsupported scheme {scheme}, should use http:// or https://"
        super().__init__(message)


class ProxySchemeUnsupported(ValueError):
    """Fetching HTTPS resources through HTTPS proxies is unsupported"""


class HeaderParsingError(HTTPError):
    """Raised by assert_header_parsing, but we convert it to a log.warning statement."""

    def __init__(
        self, defects: list[MessageDefect], unparsed_data: bytes | str | None
    ) -> None:
        message = f"{defects or 'Unknown'}, unparsed data: {unparsed_data!r}"
        super().__init__(message)


class UnrewindableBodyError(HTTPError):
    """urllib3 encountered an error when trying to rewind a body"""


class EarlyResponse(HTTPError):
    """urllib3 received a response prior to sending the whole body"""

    def __init__(self, promise: ResponsePromise) -> None:
        self.promise = promise


class ResponseNotReady(HTTPError):
    """Kept for BC"""


class RecoverableError(HTTPError):
    """This error is never leaked in the upper stack, it serves only an internal purpose."""


class MustDowngradeError(RecoverableError):
    """An error occurred with a protocol and can be circumvented using an older protocol."""


class MustRedialError(RecoverableError):
    """Unused legacy exception. Remove it in a next major."""
|
||||
283
.venv/lib/python3.9/site-packages/urllib3_future/fields.py
Normal file
283
.venv/lib/python3.9/site-packages/urllib3_future/fields.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import email.utils
|
||||
import mimetypes
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._typing import _TYPE_FIELD_VALUE, _TYPE_FIELD_VALUE_TUPLE
|
||||
|
||||
|
||||
def guess_content_type(
|
||||
filename: str | None, default: str = "application/octet-stream"
|
||||
) -> str:
|
||||
"""
|
||||
Guess the "Content-Type" of a file.
|
||||
|
||||
:param filename:
|
||||
The filename to guess the "Content-Type" of using :mod:`mimetypes`.
|
||||
:param default:
|
||||
If no "Content-Type" can be guessed, default to `default`.
|
||||
"""
|
||||
if filename:
|
||||
return mimetypes.guess_type(filename)[0] or default
|
||||
return default
|
||||
|
||||
|
||||
def format_header_param_rfc2231(name: str, value: "_TYPE_FIELD_VALUE") -> str:
    """
    Helper function to format and quote a single header parameter using the
    strategy defined in RFC 2231.

    Particularly useful for header parameters which might contain
    non-ASCII values, like file names. This follows
    `RFC 2388 Section 4.4 <https://tools.ietf.org/html/rfc2388#section-4.4>`_.

    :param name:
        The name of the parameter, a string expected to be ASCII only.
    :param value:
        The value of the parameter, provided as ``bytes`` or `str``.
    :returns:
        An RFC-2231-formatted unicode string.

    .. deprecated:: 2.0.0
        Will be removed in urllib3 v2.1.0. This is not valid for
        ``multipart/form-data`` header parameters.
    """
    text = value.decode("utf-8") if isinstance(value, bytes) else value

    # Fast path: plain quoting works when the value is pure ASCII and
    # contains none of the characters that would break a quoted-string.
    if not any(ch in text for ch in '"\\\r\n'):
        candidate = f'{name}="{text}"'
        try:
            candidate.encode("ascii")
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass
        else:
            return candidate

    # Otherwise fall back to the RFC 2231 extended-parameter encoding.
    encoded = email.utils.encode_rfc2231(text, "utf-8")
    return f"{name}*={encoded}"
|
||||
|
||||
|
||||
def format_multipart_header_param(name: str, value: "_TYPE_FIELD_VALUE") -> str:
    """
    Format and quote a single multipart header parameter.

    This follows the `WHATWG HTML Standard`_ as of 2021/06/10, matching
    the behavior of current browser and curl versions. Values are
    assumed to be UTF-8. The ``\\n``, ``\\r``, and ``"`` characters are
    percent encoded.

    .. _WHATWG HTML Standard:
        https://html.spec.whatwg.org/multipage/
        form-control-infrastructure.html#multipart-form-data

    :param name:
        The name of the parameter, an ASCII-only ``str``.
    :param value:
        The value of the parameter, a ``str`` or UTF-8 encoded
        ``bytes``.
    :returns:
        A string ``name="value"`` with the escaped value.

    .. versionchanged:: 2.0.0
        Matches the WHATWG HTML Standard as of 2021/06/10. Control
        characters are no longer percent encoded.

    .. versionchanged:: 2.0.0
        Renamed from ``format_header_param_html5`` and
        ``format_header_param``. The old names will be removed in
        urllib3 v2.1.0.
    """
    text = value.decode("utf-8") if isinstance(value, bytes) else value

    # Percent-encode the three characters the WHATWG spec escapes:
    # line feed, carriage return, and the double quote.
    escaped = text.translate(str.maketrans({"\n": "%0A", "\r": "%0D", '"': "%22"}))
    return f'{name}="{escaped}"'
|
||||
|
||||
|
||||
class RequestField:
    """
    A data container for request body parameters.

    :param name:
        The name of this request field. Must be unicode.
    :param data:
        The data/value body.
    :param filename:
        An optional filename of the request field. Must be unicode.
    :param headers:
        An optional dict-like object of headers to initially use for the field.

    .. versionchanged:: 2.0.0
        The ``header_formatter`` parameter is deprecated and will
        be removed in urllib3 v2.1.0.
    """

    def __init__(
        self,
        name: str,
        data: _TYPE_FIELD_VALUE,
        filename: str | None = None,
        headers: typing.Mapping[str, str] | None = None,
        header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
    ):
        self._name = name
        self._filename = filename
        self.data = data
        self.headers: dict[str, str | None] = {}
        if headers:
            self.headers = dict(headers)

        # Deprecated escape hatch: callers may still override how each
        # multipart header parameter is quoted.
        if header_formatter is not None:
            self.header_formatter = header_formatter
        else:
            self.header_formatter = format_multipart_header_param

    @classmethod
    def from_tuples(
        cls,
        fieldname: str,
        value: _TYPE_FIELD_VALUE_TUPLE,
        header_formatter: typing.Callable[[str, _TYPE_FIELD_VALUE], str] | None = None,
    ) -> RequestField:
        """
        A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.

        Supports constructing :class:`~urllib3.fields.RequestField` from
        parameter of key/value strings AND key/filetuple. A filetuple is a
        (filename, data, MIME type) tuple where the MIME type is optional.
        For example::

            'foo': 'bar',
            'fakefile': ('foofile.txt', 'contents of foofile'),
            'realfile': ('barfile.txt', open('realfile').read()),
            'typedfile': ('bazfile.bin', open('bazfile').read(), 'image/jpeg'),
            'nonamefile': 'contents of nonamefile field',

        Field names and filenames must be unicode.
        """
        filename: str | None
        content_type: str | None
        data: _TYPE_FIELD_VALUE

        if isinstance(value, tuple):
            if len(value) == 3:
                filename, data, content_type = value
            else:
                # Two-tuple form: infer the MIME type from the filename.
                filename, data = value
                content_type = guess_content_type(filename)
        else:
            # Bare value: a simple form field without a filename.
            filename = None
            content_type = None
            data = value

        request_param = cls(
            fieldname, data, filename=filename, header_formatter=header_formatter
        )
        request_param.make_multipart(content_type=content_type)

        return request_param

    def _render_part(self, name: str, value: _TYPE_FIELD_VALUE) -> str:
        """
        Override this method to change how each multipart header
        parameter is formatted. By default, this calls
        :func:`format_multipart_header_param`.

        :param name:
            The name of the parameter, an ASCII-only ``str``.
        :param value:
            The value of the parameter, a ``str`` or UTF-8 encoded
            ``bytes``.

        :meta public:
        """
        return self.header_formatter(name, value)

    def _render_parts(
        self,
        header_parts: (
            dict[str, _TYPE_FIELD_VALUE | None]
            | typing.Sequence[tuple[str, _TYPE_FIELD_VALUE | None]]
        ),
    ) -> str:
        """
        Helper function to format and quote a single header.

        Useful for single headers that are composed of multiple items. E.g.,
        'Content-Disposition' fields.

        :param header_parts:
            A sequence of (k, v) tuples or a :class:`dict` of (k, v) to format
            as `k1="v1"; k2="v2"; ...`.
        """
        iterable: typing.Iterable[tuple[str, _TYPE_FIELD_VALUE | None]]

        parts = []
        if isinstance(header_parts, dict):
            iterable = header_parts.items()
        else:
            iterable = header_parts

        # None values (e.g. a missing filename) are silently skipped.
        for name, value in iterable:
            if value is not None:
                parts.append(self._render_part(name, value))

        return "; ".join(parts)

    def render_headers(self) -> str:
        """
        Renders the headers for this request field.
        """
        lines = []

        # Emit the well-known headers first, in a fixed order, then any
        # remaining custom headers in insertion order.
        sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
        for sort_key in sort_keys:
            if self.headers.get(sort_key, False):
                lines.append(f"{sort_key}: {self.headers[sort_key]}")

        for header_name, header_value in self.headers.items():
            if header_name not in sort_keys:
                if header_value:
                    lines.append(f"{header_name}: {header_value}")

        # Trailing blank element yields the header/body separator CRLF.
        lines.append("\r\n")
        return "\r\n".join(lines)

    def make_multipart(
        self,
        content_disposition: str | None = None,
        content_type: str | None = None,
        content_location: str | None = None,
    ) -> None:
        """
        Makes this request field into a multipart request field.

        This method overrides "Content-Disposition", "Content-Type" and
        "Content-Location" headers to the request parameter.

        :param content_disposition:
            The 'Content-Disposition' of the request body. Defaults to 'form-data'
        :param content_type:
            The 'Content-Type' of the request body.
        :param content_location:
            The 'Content-Location' of the request body.

        """
        # The leading "" in the join produces the "; " separator between the
        # disposition token and the rendered name/filename parameters.
        content_disposition = (content_disposition or "form-data") + "; ".join(
            [
                "",
                self._render_parts(
                    (("name", self._name), ("filename", self._filename))
                ),
            ]
        )

        self.headers["Content-Disposition"] = content_disposition
        self.headers["Content-Type"] = content_type
        self.headers["Content-Location"] = content_location
|
||||
82
.venv/lib/python3.9/site-packages/urllib3_future/filepost.py
Normal file
82
.venv/lib/python3.9/site-packages/urllib3_future/filepost.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import codecs
|
||||
import os
|
||||
import typing
|
||||
from io import BytesIO
|
||||
|
||||
from ._typing import _TYPE_FIELD_VALUE_TUPLE, _TYPE_FIELDS
|
||||
from .fields import RequestField
|
||||
|
||||
writer = codecs.lookup("utf-8")[3]
|
||||
|
||||
|
||||
def choose_boundary() -> str:
    """
    Our embarrassingly-simple replacement for mimetools.choose_boundary.
    """
    # 16 random bytes rendered as 32 lowercase hex characters.
    return os.urandom(16).hex()
|
||||
|
||||
|
||||
def iter_field_objects(fields: "_TYPE_FIELDS") -> "typing.Iterable[RequestField]":
    """
    Iterate over fields.

    Supports list of (k, v) tuples and dicts, and lists of
    :class:`~urllib3.fields.RequestField`.

    """
    # A mapping contributes its (key, value) pairs; anything else is
    # assumed to already be an iterable of entries.
    entries = fields.items() if isinstance(fields, typing.Mapping) else fields

    for entry in entries:
        if isinstance(entry, RequestField):
            yield entry
        else:
            yield RequestField.from_tuples(*entry)
|
||||
|
||||
|
||||
def encode_multipart_formdata(
    fields: _TYPE_FIELDS, boundary: str | None = None
) -> tuple[bytes, str]:
    """
    Encode a dictionary of ``fields`` using the multipart/form-data MIME format.

    :param fields:
        Dictionary of fields or list of (key, :class:`~urllib3.fields.RequestField`).
        Values are processed by :func:`urllib3.fields.RequestField.from_tuples`.

    :param boundary:
        If not specified, then a random boundary will be generated using
        :func:`urllib3.filepost.choose_boundary`.
    """
    body = BytesIO()
    if boundary is None:
        boundary = choose_boundary()

    # Each part: boundary line, rendered headers, payload, trailing CRLF.
    for field in iter_field_objects(fields):
        body.write(f"--{boundary}\r\n".encode("latin-1"))

        writer(body).write(field.render_headers())
        data = field.data

        if isinstance(data, int):
            data = str(data)  # Backwards compatibility

        # Text payloads go through the UTF-8 stream writer; raw bytes are
        # written verbatim.
        if isinstance(data, str):
            writer(body).write(data)
        else:
            body.write(data)

        body.write(b"\r\n")

    # Closing boundary marks the end of the multipart body.
    body.write(f"--{boundary}--\r\n".encode("latin-1"))

    content_type = f"multipart/form-data; boundary={boundary}"

    return body.getvalue(), content_type
|
||||
@@ -0,0 +1,23 @@
|
||||
# Dummy file to match upstream modules
|
||||
# without actually serving them.
|
||||
# urllib3-future diverged from urllib3.
|
||||
# only the top-level (public API) are guaranteed to be compatible.
|
||||
# in-fact urllib3-future propose a better way to migrate/transition toward
|
||||
# newer protocols.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
|
||||
def inject_into_urllib3() -> None:
    """Compatibility shim: emit a warning instead of patching anything.

    Upstream urllib3 exposes an ``http2`` injection hook; urllib3-future
    supports newer protocols natively, so calling this only warns.
    """
    message = (
        "urllib3-future do not propose the http2 module as it is useless to us. "
        "enjoy HTTP/1.1, HTTP/2, and HTTP/3 without hacks. urllib3-future just works out "
        "of the box with all protocols. No hassles."
    )
    warnings.warn(message, UserWarning)
|
||||
|
||||
def extract_from_urllib3() -> None:
    """Compatibility no-op: there is nothing to undo here."""
1172
.venv/lib/python3.9/site-packages/urllib3_future/poolmanager.py
Normal file
1172
.venv/lib/python3.9/site-packages/urllib3_future/poolmanager.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
# Instruct type checkers to look for inline type annotations in this package.
|
||||
# See PEP 561.
|
||||
1091
.venv/lib/python3.9/site-packages/urllib3_future/response.py
Normal file
1091
.venv/lib/python3.9/site-packages/urllib3_future/response.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,41 @@
|
||||
# For backwards compatibility, provide imports that used to be here.
|
||||
from __future__ import annotations
|
||||
|
||||
from .connection import is_connection_dropped
|
||||
from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers
|
||||
from .response import is_fp_closed, parse_alt_svc
|
||||
from .retry import Retry
|
||||
from .ssl_ import (
|
||||
ALPN_PROTOCOLS,
|
||||
SSLContext,
|
||||
assert_fingerprint,
|
||||
create_urllib3_context,
|
||||
resolve_cert_reqs,
|
||||
resolve_ssl_version,
|
||||
ssl_wrap_socket,
|
||||
)
|
||||
from .timeout import Timeout
|
||||
from .url import Url, parse_url
|
||||
from .wait import wait_for_read, wait_for_write
|
||||
|
||||
# Names re-exported for backwards compatibility with ``urllib3.util``;
# each one is imported above from the submodule that actually defines it.
__all__ = (
    "SSLContext",
    "ALPN_PROTOCOLS",
    "Retry",
    "Timeout",
    "Url",
    "assert_fingerprint",
    "create_urllib3_context",
    "is_connection_dropped",
    "is_fp_closed",
    "parse_url",
    "make_headers",
    "resolve_cert_reqs",
    "resolve_ssl_version",
    "ssl_wrap_socket",
    "wait_for_read",
    "wait_for_write",
    "SKIP_HEADER",
    "SKIPPABLE_HEADERS",
    "parse_alt_svc",
)
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import ssl
|
||||
|
||||
from ...contrib.imcc import load_cert_chain as _ctx_load_cert_chain
|
||||
from ...contrib.ssa import AsyncSocket, SSLAsyncSocket
|
||||
from ...exceptions import SSLError
|
||||
from ..ssl_ import (
|
||||
ALPN_PROTOCOLS,
|
||||
_CacheableSSLContext,
|
||||
_is_key_file_encrypted,
|
||||
create_urllib3_context,
|
||||
_KnownCaller,
|
||||
)
|
||||
|
||||
|
||||
class DummyLock:
    """No-op lock stand-in: supports ``with`` but provides no synchronization."""

    def __enter__(self) -> DummyLock:
        # Nothing to acquire; hand back the instance itself.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore[no-untyped-def]
        # Nothing to release; returning None means exceptions propagate.
        return None
|
||||
|
||||
class _NoLock_CacheableSSLContext(_CacheableSSLContext):
    """Deprecated: we no longer avoid the lock in the async part because we want to allow many loop within many thread..."""

    def __init__(self, maxsize: int | None = 32):
        # Inherit the regular cache behavior first, then swap the parent's
        # real lock for a no-op DummyLock so ``with self._lock`` never blocks.
        super().__init__(maxsize=maxsize)
        self._lock = DummyLock()  # type: ignore[assignment]
|
||||
|
||||
# Module-level cache of configured SSLContext objects, shared by
# ssl_wrap_socket() below and keyed (via .lock(...)) on the TLS-relevant
# arguments it receives.
_SSLContextCache = _CacheableSSLContext()
|
||||
|
||||
async def ssl_wrap_socket(
    sock: AsyncSocket,
    keyfile: str | None = None,
    certfile: str | None = None,
    cert_reqs: int | None = None,
    ca_certs: str | None = None,
    server_hostname: str | None = None,
    ssl_version: int | None = None,
    ciphers: str | None = None,
    ssl_context: ssl.SSLContext | None = None,
    ca_cert_dir: str | None = None,
    key_password: str | None = None,
    ca_cert_data: None | str | bytes = None,
    tls_in_tls: bool = False,
    alpn_protocols: list[str] | None = None,
    certdata: str | bytes | None = None,
    keydata: str | bytes | None = None,
    check_hostname: bool | None = None,
    ssl_minimum_version: int | None = None,
    ssl_maximum_version: int | None = None,
) -> SSLAsyncSocket:
    """
    All arguments except for server_hostname, ssl_context, and ca_cert_dir have
    the same meaning as they do when using :func:`ssl.wrap_socket`.

    :param server_hostname:
        When SNI is supported, the expected hostname of the certificate
    :param ssl_context:
        A pre-made :class:`SSLContext` object. If none is provided, one will
        be created using :func:`create_urllib3_context`. Supplying one also
        disables the module-level context cache for this call.
    :param ciphers:
        A string of ciphers we wish the client to support.
    :param ca_cert_dir:
        A directory containing CA certificates in multiple separate files, as
        supported by OpenSSL's -CApath flag or the capath argument to
        SSLContext.load_verify_locations().
    :param key_password:
        Optional password if the keyfile is encrypted.
    :param ca_cert_data:
        Optional string containing CA certificates in PEM format suitable for
        passing as the cadata parameter to SSLContext.load_verify_locations()
    :param tls_in_tls:
        No-op in asynchronous mode. Call wrap_socket of the SSLAsyncSocket later.
    :param alpn_protocols:
        Manually specify other protocols to be announced during tls handshake.
    :param certdata:
        Specify an in-memory client intermediary certificate for mTLS.
    :param keydata:
        Specify an in-memory client intermediary key for mTLS.
    """
    context = ssl_context

    # A caller-supplied context is used as-is: the module-level cache is
    # neither consulted nor updated in that case.
    cache_disabled: bool = context is not None

    # Every TLS-relevant argument participates in the cache key. Paths are
    # wrapped in Path — presumably so equivalent path spellings compare
    # equal in the key; TODO confirm against _CacheableSSLContext.lock.
    with _SSLContextCache.lock(
        keyfile,
        certfile if certfile is None else Path(certfile),
        cert_reqs,
        ca_certs,
        ssl_version,
        ciphers,
        ca_cert_dir if ca_cert_dir is None else Path(ca_cert_dir),
        alpn_protocols,
        certdata,
        keydata,
        key_password,
        ca_cert_data,
        # SSLKEYLOGFILE is part of the key — presumably because context
        # construction honors it; verify in create_urllib3_context.
        os.getenv("SSLKEYLOGFILE", None),
        ssl_minimum_version,
        ssl_maximum_version,
        check_hostname,
    ):
        cached_ctx = _SSLContextCache.get() if not cache_disabled else None

        if cached_ctx is None:
            # Cache miss (or caching disabled): build and configure a context.
            if context is None:
                context = create_urllib3_context(
                    ssl_version,
                    cert_reqs,
                    ciphers=ciphers,
                    caller_id=_KnownCaller.NIQUESTS,
                    ssl_minimum_version=ssl_minimum_version,
                    ssl_maximum_version=ssl_maximum_version,
                )

            # Explicit arguments override whatever the (possibly caller-made)
            # context was created with.
            if cert_reqs is not None:
                context.verify_mode = cert_reqs  # type: ignore[assignment]

            if check_hostname is not None:
                context.check_hostname = check_hostname

            if ca_certs or ca_cert_dir or ca_cert_data:
                # SSLContext does not support bytes for cadata[...]
                if ca_cert_data and isinstance(ca_cert_data, bytes):
                    ca_cert_data = ca_cert_data.decode()

                try:
                    context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
                except OSError as e:
                    # Surface file-system failures as the library's SSLError.
                    raise SSLError(e) from e

            elif hasattr(context, "load_default_certs"):
                store_stats = context.cert_store_stats()
                # try to load OS default certs; works well on Windows.
                if "x509_ca" not in store_stats or not store_stats["x509_ca"]:
                    context.load_default_certs()

            # Attempt to detect if we get the goofy behavior of the
            # keyfile being encrypted and OpenSSL asking for the
            # passphrase via the terminal and instead error out.
            if keyfile and key_password is None and _is_key_file_encrypted(keyfile):
                raise SSLError("Client private key is encrypted, password is required")

            if certfile:
                # On-disk client certificate/key for mTLS.
                if key_password is None:
                    context.load_cert_chain(certfile, keyfile)
                else:
                    context.load_cert_chain(certfile, keyfile, key_password)
            elif certdata and keydata:
                # In-memory client certificate/key for mTLS; best-effort,
                # platform support varies (hence the warning fallback).
                try:
                    _ctx_load_cert_chain(context, certdata, keydata, key_password)
                except io.UnsupportedOperation as e:
                    warnings.warn(
                        f"""Passing in-memory client/intermediary certificate for mTLS is unsupported on your platform.
Reason: {e}. It will be picked out if you upgrade to a QUIC connection.""",
                        UserWarning,
                    )

            try:
                context.set_alpn_protocols(alpn_protocols or ALPN_PROTOCOLS)
            except (
                NotImplementedError
            ):  # Defensive: in CI, we always have set_alpn_protocols
                pass

            if ciphers:
                context.set_ciphers(ciphers)

            # Only contexts we built ourselves are cached for reuse.
            if not cache_disabled:
                _SSLContextCache.save(context)
        else:
            # Cache hit: reuse the previously configured context untouched.
            context = cached_ctx

    # tls_in_tls is deliberately ignored here; per the docstring, the caller
    # performs any nested wrap via SSLAsyncSocket.wrap_socket later.
    return await sock.wrap_socket(context, server_hostname=server_hostname)
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user