fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
187
.venv/lib/python3.9/site-packages/niquests/__init__.py
Normal file
187
.venv/lib/python3.9/site-packages/niquests/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# __
|
||||
# /__) _ _ _ _ _/ _
|
||||
# / ( (- (/ (/ (- _) / _)
|
||||
# /
|
||||
|
||||
"""
|
||||
Niquests HTTP Library
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Niquests is an HTTP library, written in Python, for human beings.
|
||||
Basic GET usage:
|
||||
|
||||
>>> import niquests
|
||||
>>> r = niquests.get('https://www.python.org')
|
||||
>>> r.status_code
|
||||
200
|
||||
>>> b'Python is a programming language' in r.content
|
||||
True
|
||||
|
||||
... or POST:
|
||||
|
||||
>>> payload = dict(key1='value1', key2='value2')
|
||||
>>> r = niquests.post('https://httpbin.org/post', data=payload)
|
||||
>>> print(r.text)
|
||||
{
|
||||
...
|
||||
"form": {
|
||||
"key1": "value1",
|
||||
"key2": "value2"
|
||||
},
|
||||
...
|
||||
}
|
||||
|
||||
The other HTTP methods are supported - see `requests.api`. Full documentation
|
||||
is at <https://niquests.readthedocs.io>.
|
||||
|
||||
:copyright: (c) 2017 by Kenneth Reitz.
|
||||
:license: Apache 2.0, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Set default logging handler to avoid "No handler found" warnings.
|
||||
import logging
|
||||
import warnings
|
||||
from logging import NullHandler
|
||||
|
||||
from ._compat import HAS_LEGACY_URLLIB3
|
||||
from .extensions.revocation import RevocationConfiguration, RevocationStrategy
|
||||
from .packages.urllib3 import (
|
||||
Retry as RetryConfiguration,
|
||||
)
|
||||
from .packages.urllib3 import (
|
||||
Timeout as TimeoutConfiguration,
|
||||
)
|
||||
from .packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
from .packages.urllib3.exceptions import DependencyWarning
|
||||
|
||||
# urllib3's DependencyWarnings should be silenced.
|
||||
warnings.simplefilter("ignore", DependencyWarning)
|
||||
|
||||
# ruff: noqa: E402
|
||||
from . import utils
|
||||
from .__version__ import (
|
||||
__author__,
|
||||
__author_email__,
|
||||
__build__,
|
||||
__cake__,
|
||||
__copyright__,
|
||||
__description__,
|
||||
__license__,
|
||||
__title__,
|
||||
__url__,
|
||||
__version__,
|
||||
)
|
||||
from .api import delete, get, head, options, patch, post, put, request
|
||||
from .async_api import (
|
||||
delete as adelete,
|
||||
)
|
||||
from .async_api import (
|
||||
get as aget,
|
||||
)
|
||||
from .async_api import (
|
||||
head as ahead,
|
||||
)
|
||||
from .async_api import (
|
||||
options as aoptions,
|
||||
)
|
||||
from .async_api import (
|
||||
patch as apatch,
|
||||
)
|
||||
from .async_api import (
|
||||
post as apost,
|
||||
)
|
||||
from .async_api import (
|
||||
put as aput,
|
||||
)
|
||||
from .async_api import (
|
||||
request as arequest,
|
||||
)
|
||||
from .async_session import AsyncSession
|
||||
from .exceptions import (
|
||||
ConnectionError,
|
||||
ConnectTimeout,
|
||||
FileModeWarning,
|
||||
HTTPError,
|
||||
JSONDecodeError,
|
||||
ReadTimeout,
|
||||
RequestException,
|
||||
RequestsDependencyWarning,
|
||||
Timeout,
|
||||
TooManyRedirects,
|
||||
URLRequired,
|
||||
)
|
||||
from .hooks import (
|
||||
AsyncLeakyBucketLimiter,
|
||||
AsyncLifeCycleHook,
|
||||
AsyncTokenBucketLimiter,
|
||||
LeakyBucketLimiter,
|
||||
LifeCycleHook,
|
||||
TokenBucketLimiter,
|
||||
)
|
||||
from .models import AsyncResponse, PreparedRequest, Request, Response
|
||||
from .sessions import Session
|
||||
from .status_codes import codes
|
||||
|
||||
logging.getLogger(__name__).addHandler(NullHandler())
|
||||
|
||||
__all__ = (
|
||||
"RequestsDependencyWarning",
|
||||
"utils",
|
||||
"__author__",
|
||||
"__author_email__",
|
||||
"__build__",
|
||||
"__cake__",
|
||||
"__copyright__",
|
||||
"__description__",
|
||||
"__license__",
|
||||
"__title__",
|
||||
"__url__",
|
||||
"__version__",
|
||||
"delete",
|
||||
"get",
|
||||
"head",
|
||||
"options",
|
||||
"patch",
|
||||
"post",
|
||||
"put",
|
||||
"request",
|
||||
"adelete",
|
||||
"aget",
|
||||
"ahead",
|
||||
"aoptions",
|
||||
"apatch",
|
||||
"apost",
|
||||
"aput",
|
||||
"arequest",
|
||||
"ConnectionError",
|
||||
"ConnectTimeout",
|
||||
"FileModeWarning",
|
||||
"HTTPError",
|
||||
"JSONDecodeError",
|
||||
"ReadTimeout",
|
||||
"RequestException",
|
||||
"Timeout",
|
||||
"TooManyRedirects",
|
||||
"URLRequired",
|
||||
"PreparedRequest",
|
||||
"Request",
|
||||
"Response",
|
||||
"Session",
|
||||
"codes",
|
||||
"AsyncSession",
|
||||
"AsyncResponse",
|
||||
"TimeoutConfiguration",
|
||||
"RetryConfiguration",
|
||||
"HAS_LEGACY_URLLIB3",
|
||||
"AsyncLifeCycleHook",
|
||||
"AsyncLeakyBucketLimiter",
|
||||
"AsyncTokenBucketLimiter",
|
||||
"LifeCycleHook",
|
||||
"LeakyBucketLimiter",
|
||||
"TokenBucketLimiter",
|
||||
"RevocationConfiguration",
|
||||
"RevocationStrategy",
|
||||
"ServerSentEvent",
|
||||
)
|
||||
19
.venv/lib/python3.9/site-packages/niquests/__version__.py
Normal file
19
.venv/lib/python3.9/site-packages/niquests/__version__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# .-. .-. .-. . . .-. .-. .-. .-.
|
||||
# |( |- |.| | | |- `-. | `-.
|
||||
# ' ' `-' `-`.`-' `-' `-' ' `-'
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__title__: str = "niquests"
|
||||
__description__: str = "Python HTTP for Humans."
|
||||
__url__: str = "https://niquests.readthedocs.io"
|
||||
|
||||
__version__: str
|
||||
__version__ = "3.18.2"
|
||||
|
||||
__build__: int = 0x031802
|
||||
__author__: str = "Kenneth Reitz"
|
||||
__author_email__: str = "me@kennethreitz.org"
|
||||
__license__: str = "Apache-2.0"
|
||||
__copyright__: str = "Copyright Kenneth Reitz"
|
||||
__cake__: str = "\u2728 \U0001f370 \u2728"
|
||||
14
.venv/lib/python3.9/site-packages/niquests/_async.py
Normal file
14
.venv/lib/python3.9/site-packages/niquests/_async.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"importing niquests._async is deprecated and absolutely discouraged. "
|
||||
"It will be removed in a future release. In general, never import private "
|
||||
"modules."
|
||||
),
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
from .async_session import * # noqa
|
||||
94
.venv/lib/python3.9/site-packages/niquests/_compat.py
Normal file
94
.venv/lib/python3.9/site-packages/niquests/_compat.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
|
||||
# asyncio.iscoroutinefunction is deprecated in Python 3.14 and will be removed in 3.16.
|
||||
# Use inspect.iscoroutinefunction for Python 3.14+ and asyncio.iscoroutinefunction for earlier.
|
||||
# Note: There are subtle behavioral differences between the two functions, but for
|
||||
# the use cases in niquests (checking if callbacks/hooks are async), both should work.
|
||||
if sys.version_info >= (3, 14):
|
||||
import inspect
|
||||
|
||||
iscoroutinefunction = inspect.iscoroutinefunction
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
iscoroutinefunction = asyncio.iscoroutinefunction
|
||||
|
||||
try:
|
||||
from urllib3 import __version__
|
||||
|
||||
HAS_LEGACY_URLLIB3: bool = int(__version__.split(".")[-1]) < 900
|
||||
except (ValueError, ImportError): # Defensive: tested in separate CI
|
||||
# Means one of the two cases:
|
||||
# 1) urllib3 does not exist -> fallback to urllib3_future
|
||||
# 2) urllib3 exist but not fork -> fallback to urllib3_future
|
||||
HAS_LEGACY_URLLIB3 = True
|
||||
|
||||
if HAS_LEGACY_URLLIB3: # Defensive: tested in separate CI
|
||||
import urllib3_future
|
||||
else:
|
||||
urllib3_future = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import urllib3
|
||||
|
||||
# force detect broken or dummy urllib3 package
|
||||
urllib3.Timeout # noqa
|
||||
urllib3.Retry # noqa
|
||||
urllib3.__version__ # noqa
|
||||
except (ImportError, AttributeError): # Defensive: tested in separate CI
|
||||
urllib3 = None # type: ignore[assignment]
|
||||
|
||||
|
||||
if (urllib3 is None and urllib3_future is None) or (HAS_LEGACY_URLLIB3 and urllib3_future is None):
|
||||
raise RuntimeError( # Defensive: tested in separate CI
|
||||
"This is awkward but your environment is missing urllib3-future. "
|
||||
"Your environment seems broken. "
|
||||
"You may fix this issue by running `python -m pip install niquests -U` "
|
||||
"to force reinstall its dependencies."
|
||||
)
|
||||
|
||||
if urllib3 is not None:
|
||||
T = typing.TypeVar("T", urllib3.Timeout, urllib3.Retry)
|
||||
else: # Defensive: tested in separate CI
|
||||
T = typing.TypeVar("T", urllib3_future.Timeout, urllib3_future.Retry) # type: ignore
|
||||
|
||||
|
||||
def urllib3_ensure_type(o: T) -> T:
|
||||
"""Retry, Timeout must be the one in urllib3_future."""
|
||||
if urllib3 is None:
|
||||
return o
|
||||
|
||||
if HAS_LEGACY_URLLIB3: # Defensive: tested in separate CI
|
||||
if "urllib3_future" not in str(type(o)):
|
||||
assert urllib3_future is not None
|
||||
|
||||
if isinstance(o, urllib3.Timeout):
|
||||
return urllib3_future.Timeout( # type: ignore[return-value]
|
||||
o.total, # type: ignore[arg-type]
|
||||
o.connect_timeout, # type: ignore[arg-type]
|
||||
o.read_timeout, # type: ignore[arg-type]
|
||||
)
|
||||
if isinstance(o, urllib3.Retry):
|
||||
return urllib3_future.Retry( # type: ignore[return-value]
|
||||
o.total,
|
||||
o.connect,
|
||||
o.read,
|
||||
redirect=o.redirect,
|
||||
status=o.status,
|
||||
other=o.other,
|
||||
allowed_methods=o.allowed_methods,
|
||||
status_forcelist=o.status_forcelist,
|
||||
backoff_factor=o.backoff_factor,
|
||||
backoff_max=o.backoff_max,
|
||||
raise_on_redirect=o.raise_on_redirect,
|
||||
raise_on_status=o.raise_on_status,
|
||||
history=o.history, # type: ignore[arg-type]
|
||||
respect_retry_after_header=o.respect_retry_after_header,
|
||||
remove_headers_on_redirect=o.remove_headers_on_redirect,
|
||||
backoff_jitter=o.backoff_jitter,
|
||||
)
|
||||
|
||||
return o
|
||||
25
.venv/lib/python3.9/site-packages/niquests/_constant.py
Normal file
25
.venv/lib/python3.9/site-packages/niquests/_constant.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .typing import RetryType, TimeoutType
|
||||
|
||||
#: Default timeout (total) assigned for GET, HEAD, and OPTIONS methods.
|
||||
READ_DEFAULT_TIMEOUT: TimeoutType = 30
|
||||
#: Default timeout (total) assigned for DELETE, PUT, PATCH, and POST.
|
||||
WRITE_DEFAULT_TIMEOUT: TimeoutType = 120
|
||||
|
||||
DEFAULT_POOLBLOCK: bool = False
|
||||
DEFAULT_POOLSIZE: int = 10
|
||||
DEFAULT_RETRIES: RetryType = 0
|
||||
|
||||
|
||||
# we don't want to eagerly load this as some user just
|
||||
# don't leverage ssl anyway. this should make niquests
|
||||
# import generally faster.
|
||||
def __getattr__(name: str):
|
||||
if name == "DEFAULT_CA_BUNDLE":
|
||||
import wassima
|
||||
|
||||
val = wassima.generate_ca_bundle()
|
||||
globals()["DEFAULT_CA_BUNDLE"] = val
|
||||
return val
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
14
.venv/lib/python3.9/site-packages/niquests/_typing.py
Normal file
14
.venv/lib/python3.9/site-packages/niquests/_typing.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
"importing niquests._typing is deprecated and absolutely discouraged. "
|
||||
"It will be removed in a future release. In general, never import private "
|
||||
"modules."
|
||||
),
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
from .typing import * # noqa
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 TAHRI Ahmed R.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Kiss-Headers
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
Kiss-Headers is a headers, HTTP or IMAP4 _(message, email)_ flavour, utility, written in pure Python, for humans.
|
||||
Object oriented headers. Keep it sweet and simple.
|
||||
Basic usage:
|
||||
|
||||
>>> import requests
|
||||
>>> from kiss_headers import parse_it
|
||||
>>> r = requests.get('https://www.python.org')
|
||||
>>> headers = parse_it(r)
|
||||
>>> 'charset' in headers.content_type
|
||||
True
|
||||
>>> headers.content_type.charset
|
||||
'utf-8'
|
||||
>>> 'text/html' in headers.content_type
|
||||
True
|
||||
>>> headers.content_type == 'text/html'
|
||||
True
|
||||
>>> headers -= 'content-type'
|
||||
>>> 'Content-Type' in headers
|
||||
False
|
||||
|
||||
... or from a raw IMAP4 message:
|
||||
|
||||
>>> message = requests.get("https://gist.githubusercontent.com/Ousret/8b84b736c375bb6aa3d389e86b5116ec/raw/21cb2f7af865e401c37d9b053fb6fe1abf63165b/sample-message.eml").content
|
||||
>>> headers = parse_it(message)
|
||||
>>> 'Sender' in headers
|
||||
True
|
||||
|
||||
Others methods and usages are available - see the full documentation
|
||||
at <https://github.com/jawah/kiss-headers>.
|
||||
|
||||
:copyright: (c) 2020 by Ahmed TAHRI
|
||||
:license: MIT, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .api import dumps, explain, get_polymorphic, parse_it
|
||||
from .builder import (
|
||||
Accept,
|
||||
AcceptEncoding,
|
||||
AcceptLanguage,
|
||||
Allow,
|
||||
AltSvc,
|
||||
Authorization,
|
||||
BasicAuthorization,
|
||||
CacheControl,
|
||||
Connection,
|
||||
ContentDisposition,
|
||||
ContentEncoding,
|
||||
ContentLength,
|
||||
ContentRange,
|
||||
ContentSecurityPolicy,
|
||||
ContentType,
|
||||
CrossOriginResourcePolicy,
|
||||
CustomHeader,
|
||||
Date,
|
||||
Digest,
|
||||
Dnt,
|
||||
Etag,
|
||||
Expires,
|
||||
Forwarded,
|
||||
From,
|
||||
Host,
|
||||
IfMatch,
|
||||
IfModifiedSince,
|
||||
IfNoneMatch,
|
||||
IfUnmodifiedSince,
|
||||
KeepAlive,
|
||||
LastModified,
|
||||
Location,
|
||||
ProxyAuthorization,
|
||||
Referer,
|
||||
ReferrerPolicy,
|
||||
RetryAfter,
|
||||
Server,
|
||||
SetCookie,
|
||||
StrictTransportSecurity,
|
||||
TransferEncoding,
|
||||
UpgradeInsecureRequests,
|
||||
UserAgent,
|
||||
Vary,
|
||||
WwwAuthenticate,
|
||||
XContentTypeOptions,
|
||||
XDnsPrefetchControl,
|
||||
XFrameOptions,
|
||||
XXssProtection,
|
||||
)
|
||||
from .models import Attributes, Header, Headers, lock_output_type
|
||||
from .serializer import decode, encode
|
||||
from .version import VERSION, __version__
|
||||
|
||||
__all__ = (
|
||||
"dumps",
|
||||
"explain",
|
||||
"get_polymorphic",
|
||||
"parse_it",
|
||||
"Attributes",
|
||||
"Header",
|
||||
"Headers",
|
||||
"lock_output_type",
|
||||
"decode",
|
||||
"encode",
|
||||
"VERSION",
|
||||
"__version__",
|
||||
"Accept",
|
||||
"AcceptEncoding",
|
||||
"AcceptLanguage",
|
||||
"Allow",
|
||||
"AltSvc",
|
||||
"Authorization",
|
||||
"BasicAuthorization",
|
||||
"CacheControl",
|
||||
"Connection",
|
||||
"ContentDisposition",
|
||||
"ContentEncoding",
|
||||
"ContentLength",
|
||||
"ContentRange",
|
||||
"ContentSecurityPolicy",
|
||||
"ContentType",
|
||||
"CrossOriginResourcePolicy",
|
||||
"CustomHeader",
|
||||
"Date",
|
||||
"Digest",
|
||||
"Dnt",
|
||||
"Etag",
|
||||
"Expires",
|
||||
"Forwarded",
|
||||
"From",
|
||||
"Host",
|
||||
"IfMatch",
|
||||
"IfModifiedSince",
|
||||
"IfNoneMatch",
|
||||
"IfUnmodifiedSince",
|
||||
"KeepAlive",
|
||||
"LastModified",
|
||||
"Location",
|
||||
"ProxyAuthorization",
|
||||
"Referer",
|
||||
"ReferrerPolicy",
|
||||
"RetryAfter",
|
||||
"Server",
|
||||
"SetCookie",
|
||||
"StrictTransportSecurity",
|
||||
"TransferEncoding",
|
||||
"UpgradeInsecureRequests",
|
||||
"UserAgent",
|
||||
"Vary",
|
||||
"WwwAuthenticate",
|
||||
"XContentTypeOptions",
|
||||
"XDnsPrefetchControl",
|
||||
"XFrameOptions",
|
||||
"XXssProtection",
|
||||
)
|
||||
@@ -0,0 +1,209 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from email.message import Message
|
||||
from email.parser import HeaderParser
|
||||
from io import BufferedReader, RawIOBase
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from typing import Any, Iterable, Mapping, TypeVar
|
||||
|
||||
from .builder import CustomHeader
|
||||
from .models import Header, Headers
|
||||
from .serializer import decode, encode
|
||||
from .structures import CaseInsensitiveDict
|
||||
from .utils import (
|
||||
class_to_header_name,
|
||||
decode_partials,
|
||||
extract_class_name,
|
||||
extract_encoded_headers,
|
||||
header_content_split,
|
||||
header_name_to_class,
|
||||
is_content_json_object,
|
||||
is_legal_header_name,
|
||||
normalize_str,
|
||||
transform_possible_encoded,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound=CustomHeader, covariant=True)
|
||||
|
||||
|
||||
def parse_it(raw_headers: Any) -> Headers:
|
||||
"""
|
||||
Just decode anything that could contain headers. That simple PERIOD.
|
||||
If passed with a Headers instance, return a deep copy of it.
|
||||
:param raw_headers: Accept bytes, str, fp, dict, JSON, email.Message, requests.Response, niquests.Response, urllib3.HTTPResponse and httpx.Response.
|
||||
:raises:
|
||||
TypeError: If passed argument cannot be parsed to extract headers from it.
|
||||
"""
|
||||
|
||||
if isinstance(raw_headers, Headers):
|
||||
return deepcopy(raw_headers)
|
||||
|
||||
headers: Iterable[tuple[str | bytes, str | bytes]] | None = None
|
||||
|
||||
if isinstance(raw_headers, str):
|
||||
if raw_headers.startswith("{") and raw_headers.endswith("}"):
|
||||
return decode(json_loads(raw_headers))
|
||||
headers = HeaderParser().parsestr(raw_headers, headersonly=True).items()
|
||||
elif (
|
||||
isinstance(raw_headers, bytes)
|
||||
or isinstance(raw_headers, RawIOBase)
|
||||
or isinstance(raw_headers, BufferedReader)
|
||||
):
|
||||
decoded, not_decoded = extract_encoded_headers(
|
||||
raw_headers if isinstance(raw_headers, bytes) else raw_headers.read() or b""
|
||||
)
|
||||
return parse_it(decoded)
|
||||
elif isinstance(raw_headers, Mapping) or isinstance(raw_headers, Message):
|
||||
headers = raw_headers.items()
|
||||
else:
|
||||
r = extract_class_name(type(raw_headers))
|
||||
|
||||
if r:
|
||||
if r in [
|
||||
"requests.models.Response",
|
||||
"niquests.models.Response",
|
||||
"niquests.models.AsyncResponse",
|
||||
]:
|
||||
headers = []
|
||||
for header_name in raw_headers.raw.headers:
|
||||
for header_content in raw_headers.raw.headers.getlist(header_name):
|
||||
headers.append((header_name, header_content))
|
||||
elif r in [
|
||||
"httpx._models.Response",
|
||||
"urllib3.response.HTTPResponse",
|
||||
"urllib3._async.response.AsyncHTTPResponse",
|
||||
"urllib3_future.response.HTTPResponse",
|
||||
"urllib3_future._async.response.AsyncHTTPResponse",
|
||||
]: # pragma: no cover
|
||||
headers = raw_headers.headers.items()
|
||||
|
||||
if headers is None:
|
||||
raise TypeError( # pragma: no cover
|
||||
f"Cannot parse type {type(raw_headers)} as it is not supported by kiss-header."
|
||||
)
|
||||
|
||||
revised_headers: list[tuple[str, str]] = decode_partials(
|
||||
transform_possible_encoded(headers)
|
||||
)
|
||||
|
||||
# Sometime raw content does not begin with headers. If that is the case, search for the next line.
|
||||
if (
|
||||
len(revised_headers) == 0
|
||||
and len(raw_headers) > 0
|
||||
and (isinstance(raw_headers, bytes) or isinstance(raw_headers, str))
|
||||
):
|
||||
next_iter = raw_headers.split(
|
||||
b"\n" if isinstance(raw_headers, bytes) else "\n", # type: ignore[arg-type]
|
||||
maxsplit=1,
|
||||
)
|
||||
|
||||
if len(next_iter) >= 2:
|
||||
return parse_it(next_iter[-1])
|
||||
|
||||
# Prepare Header objects
|
||||
list_of_headers: list[Header] = []
|
||||
|
||||
for head, content in revised_headers:
|
||||
# We should ignore when a illegal name is considered as an header. We avoid ValueError (in __init__ of Header)
|
||||
if is_legal_header_name(head) is False:
|
||||
continue
|
||||
|
||||
is_json_obj: bool = is_content_json_object(content)
|
||||
entries: list[str]
|
||||
|
||||
if is_json_obj is False:
|
||||
entries = header_content_split(content, ",")
|
||||
else:
|
||||
entries = [content]
|
||||
|
||||
# Multiple entries are detected in one content at the only exception that its not IMAP header "Subject".
|
||||
if len(entries) > 1 and normalize_str(head) != "subject":
|
||||
for entry in entries:
|
||||
list_of_headers.append(Header(head, entry))
|
||||
else:
|
||||
list_of_headers.append(Header(head, content))
|
||||
|
||||
return Headers(*list_of_headers)
|
||||
|
||||
|
||||
def explain(headers: Headers) -> CaseInsensitiveDict:
|
||||
"""
|
||||
Return a brief explanation of each header present in headers if available.
|
||||
"""
|
||||
if not Header.__subclasses__():
|
||||
raise LookupError( # pragma: no cover
|
||||
"You cannot use explain() function without properly importing the public package."
|
||||
)
|
||||
|
||||
explanations: CaseInsensitiveDict = CaseInsensitiveDict()
|
||||
|
||||
for header in headers:
|
||||
if header.name in explanations:
|
||||
continue
|
||||
|
||||
try:
|
||||
target_class = header_name_to_class(header.name, Header.__subclasses__()[0])
|
||||
except TypeError:
|
||||
explanations[header.name] = "Unknown explanation."
|
||||
continue
|
||||
|
||||
explanations[header.name] = (
|
||||
target_class.__doc__.replace("\n", "").lstrip().replace(" ", " ").rstrip()
|
||||
if target_class.__doc__
|
||||
else "Missing docstring."
|
||||
)
|
||||
|
||||
return explanations
|
||||
|
||||
|
||||
def get_polymorphic(
|
||||
target: Headers | Header, desired_output: type[T]
|
||||
) -> T | list[T] | None:
|
||||
"""Experimental. Transform a Header or Headers object to its target `CustomHeader` subclass
|
||||
to access more ready-to-use methods. eg. You have a Header object named 'Set-Cookie' and you wish
|
||||
to extract the expiration date as a datetime.
|
||||
>>> header = Header("Set-Cookie", "1P_JAR=2020-03-16-21; expires=Wed, 15-Apr-2020 21:27:31 GMT")
|
||||
>>> header["expires"]
|
||||
'Wed, 15-Apr-2020 21:27:31 GMT'
|
||||
>>> from kiss_headers import SetCookie
|
||||
>>> set_cookie = get_polymorphic(header, SetCookie)
|
||||
>>> set_cookie.get_expire()
|
||||
datetime.datetime(2020, 4, 15, 21, 27, 31, tzinfo=datetime.timezone.utc)
|
||||
"""
|
||||
|
||||
if not issubclass(desired_output, Header):
|
||||
raise TypeError(
|
||||
f"The desired output should be a subclass of Header not {desired_output}."
|
||||
)
|
||||
|
||||
desired_output_header_name: str = class_to_header_name(desired_output)
|
||||
|
||||
if isinstance(target, Headers):
|
||||
r = target.get(desired_output_header_name)
|
||||
|
||||
if r is None:
|
||||
return None
|
||||
|
||||
elif isinstance(target, Header):
|
||||
if header_name_to_class(target.name, Header) != desired_output:
|
||||
raise TypeError(
|
||||
f"The target class does not match the desired output class. {target.__class__} != {desired_output}."
|
||||
)
|
||||
r = target
|
||||
else:
|
||||
raise TypeError(f"Unable to apply get_polymorphic on type {target.__class__}.")
|
||||
|
||||
# Change __class__ attribute.
|
||||
if not isinstance(r, list):
|
||||
r.__class__ = desired_output
|
||||
else:
|
||||
for header in r:
|
||||
header.__class__ = desired_output
|
||||
|
||||
return r # type: ignore
|
||||
|
||||
|
||||
def dumps(headers: Headers, **kwargs: Any | None) -> str:
|
||||
return json_dumps(encode(headers), **kwargs) # type: ignore
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import Header, Headers
|
||||
|
||||
|
||||
def encode(headers: Headers) -> dict[str, list[dict]]:
|
||||
"""
|
||||
Provide an opinionated but reliable way to encode headers to dict for serialization purposes.
|
||||
"""
|
||||
result: dict[str, list[dict]] = dict()
|
||||
|
||||
for header in headers:
|
||||
if header.name not in result:
|
||||
result[header.name] = list()
|
||||
|
||||
encoded_header: dict[str, str | None | list[str]] = dict()
|
||||
|
||||
for attribute, value in header:
|
||||
if attribute not in encoded_header:
|
||||
encoded_header[attribute] = value
|
||||
continue
|
||||
|
||||
if isinstance(encoded_header[attribute], list) is False:
|
||||
# Here encoded_header[attribute] most certainly is str
|
||||
# Had to silent mypy error.
|
||||
encoded_header[attribute] = [encoded_header[attribute]] # type: ignore
|
||||
|
||||
encoded_header[attribute].append(value) # type: ignore
|
||||
|
||||
result[header.name].append(encoded_header)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def decode(encoded_headers: dict[str, list[dict]]) -> Headers:
|
||||
"""
|
||||
Decode any previously encoded headers to a Headers object.
|
||||
"""
|
||||
headers: Headers = Headers()
|
||||
|
||||
for header_name, encoded_header_list in encoded_headers.items():
|
||||
if not isinstance(encoded_header_list, list):
|
||||
raise ValueError("Decode require first level values to be List")
|
||||
|
||||
for encoded_header in encoded_header_list:
|
||||
if not isinstance(encoded_header, dict):
|
||||
raise ValueError("Decode require each list element to be Dict")
|
||||
|
||||
header = Header(header_name, "")
|
||||
|
||||
for attr, value in encoded_header.items():
|
||||
if value is None:
|
||||
header += attr
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
header[attr] = value
|
||||
continue
|
||||
|
||||
for sub_value in value:
|
||||
header.insert(-1, **{attr: sub_value})
|
||||
|
||||
headers += header
|
||||
|
||||
return headers
|
||||
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import (
|
||||
Any,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
from typing import (
|
||||
MutableMapping as MutableMappingType,
|
||||
)
|
||||
|
||||
from .utils import normalize_str
|
||||
|
||||
"""
|
||||
Disclaimer : CaseInsensitiveDict has been borrowed from `psf/requests`.
|
||||
Minors changes has been made.
|
||||
"""
|
||||
|
||||
|
||||
class CaseInsensitiveDict(MutableMapping):
|
||||
"""A case-insensitive ``dict``-like object.
|
||||
|
||||
Implements all methods and operations of
|
||||
``MutableMapping`` as well as dict's ``copy``. Also
|
||||
provides ``lower_items``.
|
||||
|
||||
All keys are expected to be strings. The structure remembers the
|
||||
case of the last key to be set, and ``iter(instance)``,
|
||||
``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
|
||||
will contain case-sensitive keys. However, querying and contains
|
||||
testing is case insensitive::
|
||||
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Accept'] = 'application/json'
|
||||
cid['aCCEPT'] == 'application/json' # True
|
||||
list(cid) == ['Accept'] # True
|
||||
|
||||
For example, ``headers['content-encoding']`` will return the
|
||||
value of a ``'Content-Encoding'`` response header, regardless
|
||||
of how the header name was originally stored.
|
||||
|
||||
If the constructor, ``.update``, or equality comparison
|
||||
operations are given keys that have equal ``.lower()``s, the
|
||||
behavior is undefined.
|
||||
"""
|
||||
|
||||
def __init__(self, data: Mapping | None = None, **kwargs: Any):
|
||||
self._store: OrderedDict = OrderedDict()
|
||||
if data is None:
|
||||
data = {}
|
||||
self.update(data, **kwargs)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
# Use the lowercased key for lookups, but store the actual
|
||||
# key alongside the value.
|
||||
self._store[normalize_str(key)] = (key, value)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._store[normalize_str(key)][1]
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._store[normalize_str(key)]
|
||||
|
||||
def __iter__(self) -> Iterator[tuple[str, Any]]:
|
||||
return (casedkey for casedkey, mappedvalue in self._store.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def lower_items(self) -> Iterator[tuple[str, Any]]:
|
||||
"""Like iteritems(), but with all lowercase keys."""
|
||||
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, Mapping):
|
||||
other = CaseInsensitiveDict(other)
|
||||
else:
|
||||
return NotImplemented
|
||||
# Compare insensitively
|
||||
return dict(self.lower_items()) == dict(other.lower_items())
|
||||
|
||||
# Copy is required
|
||||
def copy(self) -> CaseInsensitiveDict:
|
||||
return CaseInsensitiveDict(dict(self._store.values()))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(dict(self.items()))
|
||||
|
||||
|
||||
AttributeDescription = Tuple[List[Optional[str]], List[int]]
|
||||
AttributeBag = MutableMappingType[str, AttributeDescription]
|
||||
@@ -0,0 +1,487 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from email.header import decode_header
|
||||
from json import dumps
|
||||
from re import findall, search, sub
|
||||
from typing import Any, Iterable
|
||||
|
||||
RESERVED_KEYWORD: set[str] = {
|
||||
"and_",
|
||||
"assert_",
|
||||
"in_",
|
||||
"not_",
|
||||
"pass_",
|
||||
"finally_",
|
||||
"while_",
|
||||
"yield_",
|
||||
"is_",
|
||||
"as_",
|
||||
"break_",
|
||||
"return_",
|
||||
"elif_",
|
||||
"except_",
|
||||
"def_",
|
||||
"from_",
|
||||
"for_",
|
||||
}
|
||||
|
||||
|
||||
def normalize_str(string: str) -> str:
|
||||
"""
|
||||
Normalize a string by applying on it lowercase and replacing '-' to '_'.
|
||||
>>> normalize_str("Content-Type")
|
||||
'content_type'
|
||||
>>> normalize_str("X-content-type")
|
||||
'x_content_type'
|
||||
"""
|
||||
return string.lower().replace("-", "_")
|
||||
|
||||
|
||||
def normalize_list(strings: list[str]) -> list[str]:
|
||||
"""Normalize a list of string by applying fn normalize_str over each element."""
|
||||
return list(map(normalize_str, strings))
|
||||
|
||||
|
||||
def unpack_protected_keyword(name: str) -> str:
|
||||
"""
|
||||
By choice, this project aims to allow developers to access header or attribute in header by using the property
|
||||
notation. Some keywords are protected by the language itself. So :
|
||||
When starting by a number, prepend an underscore to it. When using a protected keyword, append an underscore to it.
|
||||
>>> unpack_protected_keyword("_3to1")
|
||||
'3to1'
|
||||
>>> unpack_protected_keyword("from_")
|
||||
'from'
|
||||
>>> unpack_protected_keyword("_from")
|
||||
'_from'
|
||||
>>> unpack_protected_keyword("3")
|
||||
'3'
|
||||
>>> unpack_protected_keyword("FroM_")
|
||||
'FroM_'
|
||||
"""
|
||||
if len(name) < 2:
|
||||
return name
|
||||
|
||||
if name[0] == "_" and name[1].isdigit():
|
||||
name = name[1:]
|
||||
|
||||
if name in RESERVED_KEYWORD:
|
||||
name = name[:-1]
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def extract_class_name(type_: type) -> str | None:
|
||||
"""
|
||||
Typically extract a class name from a Type.
|
||||
"""
|
||||
r = findall(r"<class '([a-zA-Z0-9._]+)'>", str(type_))
|
||||
return r[0] if r else None
|
||||
|
||||
|
||||
def header_content_split(string: str, delimiter: str) -> list[str]:
|
||||
"""
|
||||
Take a string and split it according to the passed delimiter.
|
||||
It will ignore delimiter if inside between double quote, inside a value, or in parenthesis.
|
||||
The input string is considered perfectly formed. This function does not split coma on a day
|
||||
when attached, see "RFC 7231, section 7.1.1.2: Date".
|
||||
>>> header_content_split("Wed, 15-Apr-2020 21:27:31 GMT, Fri, 01-Jan-2038 00:00:00 GMT", ",")
|
||||
['Wed, 15-Apr-2020 21:27:31 GMT', 'Fri, 01-Jan-2038 00:00:00 GMT']
|
||||
>>> header_content_split('quic=":443"; ma=2592000; v="46,43", h3-Q050=":443"; ma=2592000, h3-Q049=":443"; ma=2592000', ",")
|
||||
['quic=":443"; ma=2592000; v="46,43"', 'h3-Q050=":443"; ma=2592000', 'h3-Q049=":443"; ma=2592000']
|
||||
>>> header_content_split("Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; rv:50.0) Gecko/20100101 Firefox/50.0", ";")
|
||||
['Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; rv:50.0) Gecko/20100101 Firefox/50.0']
|
||||
>>> header_content_split("text/html; charset=UTF-8", ";")
|
||||
['text/html', 'charset=UTF-8']
|
||||
>>> header_content_split('text/html; charset="UTF-\\\"8"', ";")
|
||||
['text/html', 'charset="UTF-"8"']
|
||||
"""
|
||||
if len(delimiter) != 1 or delimiter not in {";", ",", " "}:
|
||||
raise ValueError("Delimiter should be either semi-colon, a coma or a space.")
|
||||
|
||||
in_double_quote: bool = False
|
||||
in_parenthesis: bool = False
|
||||
in_value: bool = False
|
||||
is_on_a_day: bool = False
|
||||
|
||||
result: list[str] = [""]
|
||||
|
||||
for letter, index in zip(string, range(0, len(string))):
|
||||
if letter == '"':
|
||||
in_double_quote = not in_double_quote
|
||||
|
||||
if in_value and not in_double_quote:
|
||||
in_value = False
|
||||
|
||||
elif letter == "(" and not in_parenthesis:
|
||||
in_parenthesis = True
|
||||
elif letter == ")" and in_parenthesis:
|
||||
in_parenthesis = False
|
||||
else:
|
||||
is_on_a_day = index >= 3 and string[index - 3 : index] in {
|
||||
"Mon",
|
||||
"Tue",
|
||||
"Wed",
|
||||
"Thu",
|
||||
"Fri",
|
||||
"Sat",
|
||||
"Sun",
|
||||
}
|
||||
|
||||
if not in_double_quote:
|
||||
if not in_value and letter == "=":
|
||||
in_value = True
|
||||
elif letter == ";" and in_value:
|
||||
in_value = False
|
||||
|
||||
if in_value and letter == delimiter and not is_on_a_day:
|
||||
in_value = False
|
||||
|
||||
if letter == delimiter and (
|
||||
(in_value or in_double_quote or in_parenthesis or is_on_a_day) is False
|
||||
):
|
||||
result[-1] = result[-1].lstrip().rstrip()
|
||||
result.append("")
|
||||
|
||||
continue
|
||||
|
||||
result[-1] += letter
|
||||
|
||||
if result:
|
||||
result[-1] = result[-1].lstrip().rstrip()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def class_to_header_name(type_: type) -> str:
|
||||
"""
|
||||
Take a type and infer its header name.
|
||||
>>> from kiss_headers.builder import ContentType, XContentTypeOptions, BasicAuthorization
|
||||
>>> class_to_header_name(ContentType)
|
||||
'Content-Type'
|
||||
>>> class_to_header_name(XContentTypeOptions)
|
||||
'X-Content-Type-Options'
|
||||
>>> class_to_header_name(BasicAuthorization)
|
||||
'Authorization'
|
||||
"""
|
||||
if hasattr(type_, "__override__") and type_.__override__ is not None:
|
||||
return type_.__override__
|
||||
|
||||
class_raw_name: str = str(type_).split("'")[-2].split(".")[-1]
|
||||
|
||||
if class_raw_name.endswith("_"):
|
||||
class_raw_name = class_raw_name[:-1]
|
||||
|
||||
if class_raw_name.startswith("_"):
|
||||
class_raw_name = class_raw_name[1:]
|
||||
|
||||
header_name: str = ""
|
||||
|
||||
for letter in class_raw_name:
|
||||
if letter.isupper() and header_name != "":
|
||||
header_name += "-" + letter
|
||||
continue
|
||||
header_name += letter
|
||||
|
||||
return header_name
|
||||
|
||||
|
||||
def header_name_to_class(name: str, root_type: type) -> type:
|
||||
"""
|
||||
The opposite of class_to_header_name function. Will raise TypeError if no corresponding entry is found.
|
||||
Do it recursively from the root type.
|
||||
>>> from kiss_headers.builder import CustomHeader, ContentType, XContentTypeOptions, LastModified, Date
|
||||
>>> header_name_to_class("Content-Type", CustomHeader)
|
||||
<class 'kiss_headers.builder.ContentType'>
|
||||
>>> header_name_to_class("Last-Modified", CustomHeader)
|
||||
<class 'kiss_headers.builder.LastModified'>
|
||||
"""
|
||||
|
||||
normalized_name = normalize_str(name).replace("_", "")
|
||||
|
||||
for subclass in root_type.__subclasses__():
|
||||
class_name = extract_class_name(subclass)
|
||||
|
||||
if class_name is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not (
|
||||
hasattr(subclass, "__override__") and subclass.__override__ is not None
|
||||
)
|
||||
and normalize_str(class_name.split(".")[-1]) == normalized_name
|
||||
):
|
||||
return subclass
|
||||
|
||||
if subclass.__subclasses__():
|
||||
try:
|
||||
return header_name_to_class(name, subclass)
|
||||
except TypeError:
|
||||
continue
|
||||
|
||||
raise TypeError(f"Cannot find a class matching header named '{name}'.")
|
||||
|
||||
|
||||
def prettify_header_name(name: str) -> str:
|
||||
"""
|
||||
Take a header name and prettify it.
|
||||
>>> prettify_header_name("x-hEllo-wORLD")
|
||||
'X-Hello-World'
|
||||
>>> prettify_header_name("server")
|
||||
'Server'
|
||||
>>> prettify_header_name("contEnt-TYPE")
|
||||
'Content-Type'
|
||||
>>> prettify_header_name("content_type")
|
||||
'Content-Type'
|
||||
"""
|
||||
return "-".join([el.capitalize() for el in name.replace("_", "-").split("-")])
|
||||
|
||||
|
||||
def decode_partials(items: Iterable[tuple[str, Any]]) -> list[tuple[str, str]]:
|
||||
"""
|
||||
This function takes a list of tuples, representing headers by key, value. Where value is bytes or string containing
|
||||
(RFC 2047 encoded) partials fragments like the following :
|
||||
>>> decode_partials([("Subject", "=?iso-8859-1?q?p=F6stal?=")])
|
||||
[('Subject', 'pöstal')]
|
||||
"""
|
||||
revised_items: list[tuple[str, str]] = list()
|
||||
|
||||
for head, content in items:
|
||||
revised_content: str = ""
|
||||
|
||||
for partial, partial_encoding in decode_header(content):
|
||||
if isinstance(partial, str):
|
||||
revised_content += partial
|
||||
if isinstance(partial, bytes):
|
||||
revised_content += partial.decode(
|
||||
partial_encoding if partial_encoding is not None else "utf-8",
|
||||
errors="ignore",
|
||||
)
|
||||
|
||||
revised_items.append((head, revised_content))
|
||||
|
||||
return revised_items
|
||||
|
||||
|
||||
def unquote(string: str) -> str:
|
||||
"""
|
||||
Remove simple quote or double quote around a string if any.
|
||||
>>> unquote('"hello"')
|
||||
'hello'
|
||||
>>> unquote('"hello')
|
||||
'"hello'
|
||||
>>> unquote('"a"')
|
||||
'a'
|
||||
>>> unquote('""')
|
||||
''
|
||||
"""
|
||||
if (
|
||||
len(string) >= 2
|
||||
and (string.startswith('"') and string.endswith('"'))
|
||||
or (string.startswith("'") and string.endswith("'"))
|
||||
):
|
||||
return string[1:-1]
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def quote(string: str) -> str:
|
||||
"""
|
||||
Surround string by a double quote char.
|
||||
>>> quote("hello")
|
||||
'"hello"'
|
||||
>>> quote('"hello')
|
||||
'""hello"'
|
||||
>>> quote('"hello"')
|
||||
'"hello"'
|
||||
"""
|
||||
return '"' + unquote(string) + '"'
|
||||
|
||||
|
||||
def count_leftover_space(content: str) -> int:
|
||||
"""
|
||||
A recursive function that counts trailing white space at the end of the given string.
|
||||
>>> count_leftover_space("hello ")
|
||||
3
|
||||
>>> count_leftover_space("byebye ")
|
||||
1
|
||||
>>> count_leftover_space(" hello ")
|
||||
1
|
||||
>>> count_leftover_space(" hello ")
|
||||
4
|
||||
"""
|
||||
if content.endswith(" "):
|
||||
return count_leftover_space(content[:-1]) + 1
|
||||
return 0
|
||||
|
||||
|
||||
def header_strip(content: str, elem: str) -> str:
|
||||
"""
|
||||
Remove a member for a given header content and take care of the unneeded leftover semi-colon.
|
||||
>>> header_strip("text/html; charset=UTF-8; format=flowed", "charset=UTF-8")
|
||||
'text/html; format=flowed'
|
||||
>>> header_strip("text/html; charset=UTF-8; format=flowed", "charset=UTF-8")
|
||||
'text/html; format=flowed'
|
||||
"""
|
||||
next_semi_colon_index: int | None = None
|
||||
|
||||
try:
|
||||
elem_index: int = content.index(elem)
|
||||
except ValueError:
|
||||
# If the target element in not found within the content, just return the unmodified content.
|
||||
return content
|
||||
|
||||
elem_end_index: int = elem_index + len(elem)
|
||||
|
||||
elem = (" " * count_leftover_space(content[:elem_index])) + elem
|
||||
|
||||
try:
|
||||
next_semi_colon_index = elem_end_index + content[elem_end_index:].index(";")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
content = (
|
||||
content.replace(
|
||||
elem
|
||||
+ (
|
||||
content[elem_end_index:next_semi_colon_index] + ";"
|
||||
if next_semi_colon_index is not None
|
||||
else ""
|
||||
),
|
||||
"",
|
||||
)
|
||||
.rstrip(" ")
|
||||
.lstrip(" ")
|
||||
)
|
||||
|
||||
if content.startswith(";"):
|
||||
content = content[1:]
|
||||
|
||||
if content.endswith(";"):
|
||||
content = content[:-1]
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def is_legal_header_name(name: str) -> bool:
|
||||
"""
|
||||
Verify if a provided header name is valid.
|
||||
>>> is_legal_header_name(":hello")
|
||||
False
|
||||
>>> is_legal_header_name("hello")
|
||||
True
|
||||
>>> is_legal_header_name("Content-Type")
|
||||
True
|
||||
>>> is_legal_header_name("Hello;")
|
||||
False
|
||||
>>> is_legal_header_name("Hello\\rWorld")
|
||||
False
|
||||
>>> is_legal_header_name("Hello \\tWorld")
|
||||
False
|
||||
>>> is_legal_header_name('Hello World"')
|
||||
False
|
||||
>>> is_legal_header_name("Hello-World/")
|
||||
True
|
||||
>>> is_legal_header_name("\x07")
|
||||
False
|
||||
"""
|
||||
return (
|
||||
name != ""
|
||||
and search(r"[^\x21-\x7F]|[:;(),<>=@?\[\]\r\n\t &{}\"\\]", name) is None
|
||||
)
|
||||
|
||||
|
||||
def extract_comments(content: str) -> list[str]:
|
||||
"""
|
||||
Extract parts of content that are considered as comments. Between parenthesis.
|
||||
>>> extract_comments("Mozilla/5.0 (Macintosh; Intel Mac OS X 10.9; rv:50.0) Gecko/20100101 Firefox/50.0 (hello) llll (abc)")
|
||||
['Macintosh; Intel Mac OS X 10.9; rv:50.0', 'hello', 'abc']
|
||||
"""
|
||||
return findall(r"\(([^)]+)\)", content)
|
||||
|
||||
|
||||
def unfold(content: str) -> str:
|
||||
r"""Some header content may have folded content (CRLF + n spaces) in it, making your job at reading them a little more difficult.
|
||||
This function undoes the folding in the given content.
|
||||
>>> unfold("___utmvbtouVBFmB=gZg\r\n XbNOjalT: Lte; path=/; Max-Age=900")
|
||||
'___utmvbtouVBFmB=gZg XbNOjalT: Lte; path=/; Max-Age=900'
|
||||
"""
|
||||
return sub(r"\r\n[ ]+", " ", content)
|
||||
|
||||
|
||||
def extract_encoded_headers(payload: bytes) -> tuple[str, bytes]:
|
||||
"""This function's purpose is to extract lines that can be decoded using the UTF-8 decoder.
|
||||
>>> extract_encoded_headers("Host: developer.mozilla.org\\r\\nX-Hello-World: 死の漢字\\r\\n\\r\\n".encode("utf-8"))
|
||||
('Host: developer.mozilla.org\\r\\nX-Hello-World: 死の漢字\\r\\n', b'')
|
||||
>>> extract_encoded_headers("Host: developer.mozilla.org\\r\\nX-Hello-World: 死の漢字\\r\\n\\r\\nThat IS totally random.".encode("utf-8"))
|
||||
('Host: developer.mozilla.org\\r\\nX-Hello-World: 死の漢字\\r\\n', b'That IS totally random.')
|
||||
"""
|
||||
result: str = ""
|
||||
lines: list[bytes] = payload.splitlines()
|
||||
index: int = 0
|
||||
|
||||
for line, index in zip(lines, range(0, len(lines))):
|
||||
if line == b"":
|
||||
return result, b"\r\n".join(lines[index + 1 :])
|
||||
|
||||
try:
|
||||
result += line.decode("utf-8") + "\r\n"
|
||||
except UnicodeDecodeError:
|
||||
break
|
||||
|
||||
return result, b"\r\n".join(lines[index + 1 :])
|
||||
|
||||
|
||||
def unescape_double_quote(content: str) -> str:
|
||||
"""
|
||||
Replace escaped double quote in content by removing the backslash.
|
||||
>>> unescape_double_quote(r'UTF\"-8')
|
||||
'UTF"-8'
|
||||
>>> unescape_double_quote(r'UTF"-8')
|
||||
'UTF"-8'
|
||||
"""
|
||||
return content.replace(r"\"", '"')
|
||||
|
||||
|
||||
def escape_double_quote(content: str) -> str:
|
||||
r"""
|
||||
Replace not escaped double quote in content by adding a backslash beforehand.
|
||||
>>> escape_double_quote(r'UTF\"-8')
|
||||
'UTF\\"-8'
|
||||
>>> escape_double_quote(r'UTF"-8')
|
||||
'UTF\\"-8'
|
||||
"""
|
||||
return unescape_double_quote(content).replace('"', r"\"")
|
||||
|
||||
|
||||
def is_content_json_object(content: str) -> bool:
|
||||
"""
|
||||
Sometime, you may receive a header that hold a JSON list or object.
|
||||
This function detect it.
|
||||
"""
|
||||
content = content.strip()
|
||||
return (content.startswith("{") and content.endswith("}")) or (
|
||||
content.startswith("[") and content.endswith("]")
|
||||
)
|
||||
|
||||
|
||||
def transform_possible_encoded(
|
||||
headers: Iterable[tuple[str | bytes, str | bytes]],
|
||||
) -> Iterable[tuple[str, str]]:
|
||||
decoded = []
|
||||
|
||||
for k, v in headers:
|
||||
# we shall discard it if set to None.
|
||||
if v is None:
|
||||
continue
|
||||
if isinstance(k, bytes):
|
||||
k = k.decode("utf_8")
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode("utf_8")
|
||||
elif isinstance(v, str) is False:
|
||||
if isinstance(v, (dict, list)):
|
||||
v = dumps(v)
|
||||
else:
|
||||
v = str(v)
|
||||
decoded.append((k, v))
|
||||
|
||||
return decoded
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Expose version
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__version__ = "2.5.0"
|
||||
VERSION = __version__.split(".")
|
||||
2410
.venv/lib/python3.9/site-packages/niquests/adapters.py
Normal file
2410
.venv/lib/python3.9/site-packages/niquests/adapters.py
Normal file
File diff suppressed because it is too large
Load Diff
638
.venv/lib/python3.9/site-packages/niquests/api.py
Normal file
638
.venv/lib/python3.9/site-packages/niquests/api.py
Normal file
@@ -0,0 +1,638 @@
|
||||
"""
|
||||
requests.api
|
||||
~~~~~~~~~~~~
|
||||
|
||||
This module implements the Requests API.
|
||||
|
||||
:copyright: (c) 2012 by Kenneth Reitz.
|
||||
:license: Apache2, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import typing
|
||||
|
||||
from . import sessions
|
||||
from ._constant import DEFAULT_RETRIES, READ_DEFAULT_TIMEOUT, WRITE_DEFAULT_TIMEOUT
|
||||
from .models import PreparedRequest, Response
|
||||
from .structures import QuicSharedCache
|
||||
from .typing import (
|
||||
BodyType,
|
||||
CacheLayerAltSvcType,
|
||||
CookiesType,
|
||||
HeadersType,
|
||||
HookType,
|
||||
HttpAuthenticationType,
|
||||
HttpMethodType,
|
||||
MultiPartFilesAltType,
|
||||
MultiPartFilesType,
|
||||
ProxyType,
|
||||
QueryParameterType,
|
||||
RetryType,
|
||||
TimeoutType,
|
||||
TLSClientCertType,
|
||||
TLSVerifyType,
|
||||
)
|
||||
|
||||
_SHARED_OCSP_CACHE: typing.Any | None = None
|
||||
_SHARED_CRL_CACHE: typing.Any | None = None
|
||||
_SHARED_QUIC_CACHE: CacheLayerAltSvcType = QuicSharedCache(max_size=12_288)
|
||||
_SHARED_CACHE_LOCK: threading.RLock = threading.RLock()
|
||||
|
||||
|
||||
def request(
|
||||
method: HttpMethodType,
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
data: BodyType | None = None,
|
||||
json: typing.Any | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response:
|
||||
"""Constructs and sends a :class:`Request <Request>`.
|
||||
|
||||
This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param method: method for the new :class:`Request` object: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``,
|
||||
or ``DELETE``.
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
:return: :class:`Response <Response>` object
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import niquests
|
||||
>>> req = niquests.request('GET', 'https://httpbin.org/get')
|
||||
>>> req
|
||||
<Response HTTP/2 [200]>
|
||||
"""
|
||||
|
||||
# By using the 'with' statement we are sure the session is closed, thus we
|
||||
# avoid leaving sockets open which can trigger a ResourceWarning in some
|
||||
# cases, and look like a memory leak in others.
|
||||
global _SHARED_OCSP_CACHE, _SHARED_CRL_CACHE
|
||||
with sessions.Session(quic_cache_layer=_SHARED_QUIC_CACHE, retries=retries) as session:
|
||||
session._ocsp_cache = _SHARED_OCSP_CACHE
|
||||
session._crl_cache = _SHARED_CRL_CACHE
|
||||
try:
|
||||
return session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
data=data,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
hooks=hooks,
|
||||
stream=stream,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
json=json,
|
||||
)
|
||||
finally:
|
||||
with _SHARED_CACHE_LOCK:
|
||||
if _SHARED_OCSP_CACHE is None and session._ocsp_cache is not None:
|
||||
_SHARED_OCSP_CACHE = session._ocsp_cache
|
||||
if _SHARED_CRL_CACHE is None and session._crl_cache is not None:
|
||||
_SHARED_CRL_CACHE = session._crl_cache
|
||||
|
||||
|
||||
def get(
|
||||
url: str,
|
||||
params: QueryParameterType | None = None,
|
||||
*,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response:
|
||||
r"""Sends a GET request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"GET",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def options(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response:
|
||||
r"""Sends an OPTIONS request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"OPTIONS",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def head(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = False,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response:
|
||||
r"""Sends a HEAD request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"HEAD",
|
||||
url,
|
||||
allow_redirects=allow_redirects,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def post(
|
||||
url: str,
|
||||
data: BodyType | None = None,
|
||||
json: typing.Any | None = None,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response:
|
||||
r"""Sends a POST request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"POST",
|
||||
url,
|
||||
data=data,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
def put(
|
||||
url: str,
|
||||
data: BodyType | None = None,
|
||||
*,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response:
|
||||
r"""Sends a PUT request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"PUT",
|
||||
url,
|
||||
data=data,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
def patch(
|
||||
url: str,
|
||||
data: BodyType | None = None,
|
||||
*,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response:
|
||||
r"""Sends a PATCH request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"PATCH",
|
||||
url,
|
||||
data=data,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
def delete(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
stream: bool = False,
|
||||
cert: TLSClientCertType | None = None,
|
||||
hooks: HookType[PreparedRequest | Response] | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response:
|
||||
r"""Sends a DELETE request. This does not keep the connection alive. Use a :class:`Session` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object
|
||||
"""
|
||||
|
||||
return request(
|
||||
"DELETE",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream,
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
975
.venv/lib/python3.9/site-packages/niquests/async_api.py
Normal file
975
.venv/lib/python3.9/site-packages/niquests/async_api.py
Normal file
@@ -0,0 +1,975 @@
|
||||
"""
|
||||
requests.api
|
||||
~~~~~~~~~~~~
|
||||
|
||||
This module implements the Requests API.
|
||||
|
||||
:copyright: (c) 2012 by Kenneth Reitz.
|
||||
:license: Apache2, see LICENSE for more details.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import typing
|
||||
|
||||
from ._constant import DEFAULT_RETRIES, READ_DEFAULT_TIMEOUT, WRITE_DEFAULT_TIMEOUT
|
||||
from .async_session import AsyncSession
|
||||
from .models import AsyncResponse, PreparedRequest, Response
|
||||
from .structures import AsyncQuicSharedCache
|
||||
from .typing import (
|
||||
AsyncBodyType,
|
||||
AsyncHookType,
|
||||
AsyncHttpAuthenticationType,
|
||||
BodyType,
|
||||
CacheLayerAltSvcType,
|
||||
CookiesType,
|
||||
HeadersType,
|
||||
HttpAuthenticationType,
|
||||
HttpMethodType,
|
||||
MultiPartFilesAltType,
|
||||
MultiPartFilesType,
|
||||
ProxyType,
|
||||
QueryParameterType,
|
||||
RetryType,
|
||||
TimeoutType,
|
||||
TLSClientCertType,
|
||||
TLSVerifyType,
|
||||
)
|
||||
|
||||
_SHARED_OCSP_CACHE: contextvars.ContextVar[typing.Any | None] = contextvars.ContextVar("ocsp_cache", default=None)
|
||||
_SHARED_CRL_CACHE: contextvars.ContextVar[typing.Any | None] = contextvars.ContextVar("crl_cache", default=None)
|
||||
_SHARED_QUIC_CACHE: CacheLayerAltSvcType = AsyncQuicSharedCache(max_size=12_288)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def request(
|
||||
method: HttpMethodType,
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
stream: typing.Literal[False] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
json: typing.Any | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def request(
|
||||
method: HttpMethodType,
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
stream: typing.Literal[True] = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
json: typing.Any | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def request(
|
||||
method: HttpMethodType,
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
data: BodyType | AsyncBodyType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
stream: bool | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
json: typing.Any | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response | AsyncResponse:
|
||||
"""Constructs and sends a :class:`Request <Request>`. This does not keep the connection alive.
|
||||
Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param method: method for the new :class:`Request` object: ``GET``, ``OPTIONS``, ``HEAD``, ``POST``, ``PUT``, ``PATCH``,
|
||||
or ``DELETE``.
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
|
||||
Usage::
|
||||
|
||||
>>> import niquests
|
||||
>>> req = await niquests.arequest('GET', 'https://httpbin.org/get')
|
||||
>>> req
|
||||
<Response HTTP/2 [200]>
|
||||
"""
|
||||
|
||||
# By using the 'with' statement we are sure the session is closed, thus we
|
||||
# avoid leaving sockets open which can trigger a ResourceWarning in some
|
||||
# cases, and look like a memory leak in others.
|
||||
async with AsyncSession(quic_cache_layer=_SHARED_QUIC_CACHE, retries=retries) as session:
|
||||
session._ocsp_cache = _SHARED_OCSP_CACHE.get()
|
||||
session._crl_cache = _SHARED_CRL_CACHE.get()
|
||||
try:
|
||||
return await session.request( # type: ignore[misc]
|
||||
method,
|
||||
url,
|
||||
params,
|
||||
data,
|
||||
headers,
|
||||
cookies,
|
||||
files,
|
||||
auth,
|
||||
timeout,
|
||||
allow_redirects,
|
||||
proxies,
|
||||
hooks,
|
||||
stream, # type: ignore[arg-type]
|
||||
verify,
|
||||
cert,
|
||||
json,
|
||||
)
|
||||
finally:
|
||||
if _SHARED_OCSP_CACHE.get() is None and session._ocsp_cache is not None:
|
||||
_SHARED_OCSP_CACHE.set(session._ocsp_cache)
|
||||
if _SHARED_CRL_CACHE.get() is None and session._crl_cache is not None:
|
||||
_SHARED_CRL_CACHE.set(session._crl_cache)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def get(
|
||||
url: str,
|
||||
params: QueryParameterType | None = ...,
|
||||
*,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | None = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def get(
|
||||
url: str,
|
||||
params: QueryParameterType | None = ...,
|
||||
*,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def get(
|
||||
url: str,
|
||||
params: QueryParameterType | None = None,
|
||||
*,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a GET request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"GET",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def options(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def options(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def options(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends an OPTIONS request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"OPTIONS",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def head(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def head(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def head(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = READ_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = False,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a HEAD request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``False``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"HEAD",
|
||||
url,
|
||||
allow_redirects=allow_redirects,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def post(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
json: typing.Any | None = ...,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def post(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
json: typing.Any | None = ...,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def post(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = None,
|
||||
json: typing.Any | None = None,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a POST request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or (awaitable or not) file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"POST",
|
||||
url,
|
||||
data=data,
|
||||
json=json,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def put(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
*,
|
||||
json: typing.Any | None = ...,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def put(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
*,
|
||||
json: typing.Any | None = ...,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def put(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = None,
|
||||
*,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a PUT request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or (awaitable or not) file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"PUT",
|
||||
url,
|
||||
data=data,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def patch(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
*,
|
||||
json: typing.Any | None = ...,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def patch(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = ...,
|
||||
*,
|
||||
json: typing.Any | None = ...,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def patch(
|
||||
url: str,
|
||||
data: BodyType | AsyncBodyType | None = None,
|
||||
*,
|
||||
json: typing.Any | None = None,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
files: MultiPartFilesType | MultiPartFilesAltType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a PATCH request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param data: (optional) Dictionary, list of tuples, bytes, or (awaitable or not) file-like
|
||||
object to send in the body of the :class:`Request`.
|
||||
:param json: (optional) A JSON serializable Python object to send in the body of the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param files: (optional) Dictionary of ``'name': file-like-objects`` (or ``{'name': file-tuple}``)
|
||||
for multipart encoding upload.
|
||||
``file-tuple`` can be a 2-tuple ``('filename', fileobj)``, 3-tuple ``('filename', fileobj, 'content_type')``
|
||||
or a 4-tuple ``('filename', fileobj, 'content_type', custom_headers)``, where ``'content_type'`` is a string
|
||||
defining the content type of the given file and ``custom_headers`` a dict-like object containing additional headers
|
||||
to add for the file.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"PATCH",
|
||||
url,
|
||||
data=data,
|
||||
params=params,
|
||||
json=json,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
files=files,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
)
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def delete(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[False] | typing.Literal[None] = ...,
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response: ...
|
||||
|
||||
|
||||
@typing.overload
|
||||
async def delete(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = ...,
|
||||
headers: HeadersType | None = ...,
|
||||
cookies: CookiesType | None = ...,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = ...,
|
||||
timeout: TimeoutType | None = ...,
|
||||
allow_redirects: bool = ...,
|
||||
proxies: ProxyType | None = ...,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = ...,
|
||||
verify: TLSVerifyType | None = ...,
|
||||
stream: typing.Literal[True],
|
||||
cert: TLSClientCertType | None = ...,
|
||||
retries: RetryType = ...,
|
||||
**kwargs: typing.Any,
|
||||
) -> AsyncResponse: ...
|
||||
|
||||
|
||||
async def delete(
|
||||
url: str,
|
||||
*,
|
||||
params: QueryParameterType | None = None,
|
||||
headers: HeadersType | None = None,
|
||||
cookies: CookiesType | None = None,
|
||||
auth: HttpAuthenticationType | AsyncHttpAuthenticationType | None = None,
|
||||
timeout: TimeoutType | None = WRITE_DEFAULT_TIMEOUT,
|
||||
allow_redirects: bool = True,
|
||||
proxies: ProxyType | None = None,
|
||||
hooks: AsyncHookType[PreparedRequest | Response] | None = None,
|
||||
verify: TLSVerifyType | None = None,
|
||||
stream: bool | None = None,
|
||||
cert: TLSClientCertType | None = None,
|
||||
retries: RetryType = DEFAULT_RETRIES,
|
||||
**kwargs: typing.Any,
|
||||
) -> Response | AsyncResponse:
|
||||
r"""Sends a DELETE request. This does not keep the connection alive. Use an :class:`AsyncSession` to reuse the connection.
|
||||
|
||||
:param url: URL for the new :class:`Request` object.
|
||||
:param params: (optional) Dictionary, list of tuples or bytes to send
|
||||
in the query string for the :class:`Request`.
|
||||
:param headers: (optional) Dictionary of HTTP Headers to send with the :class:`Request`.
|
||||
:param cookies: (optional) Dict or CookieJar object to send with the :class:`Request`.
|
||||
:param auth: (optional) Auth tuple to enable Basic/Digest/Custom HTTP Auth.
|
||||
:param timeout: (optional) How many seconds to wait for the server to send data
|
||||
before giving up, as a float, or a :ref:`(connect timeout, read
|
||||
timeout) <timeouts>` tuple.
|
||||
:param allow_redirects: (optional) Boolean. Enable/disable GET/OPTIONS/POST/PUT/PATCH/DELETE/HEAD redirection.
|
||||
Defaults to ``True``.
|
||||
:param proxies: (optional) Dictionary mapping protocol to the URL of the proxy.
|
||||
:param verify: (optional) Either a boolean, in which case it controls whether we verify
|
||||
the server's TLS certificate, or a path passed as a string or os.Pathlike object,
|
||||
in which case it must be a path to a CA bundle to use.
|
||||
Defaults to ``True``.
|
||||
It is also possible to put the certificates (directly) in a string or bytes.
|
||||
:param stream: (optional) if ``False``, the response content will be immediately downloaded. Otherwise, the response will
|
||||
be of type :class:`AsyncResponse <AsyncResponse>` so that it will be awaitable.
|
||||
:param cert: (optional) if String, path to ssl client cert file (.pem).
|
||||
If Tuple, ('cert', 'key') pair, or ('cert', 'key', 'key_password').
|
||||
:param hooks: (optional) Register functions that should be called at very specific moment in the request lifecycle.
|
||||
:param retries: (optional) If integer, determine the number of retry in case of a timeout or connection error.
|
||||
Otherwise, for fine gained retry, use directly a ``Retry`` instance from urllib3.
|
||||
|
||||
:return: :class:`Response <Response>` object if stream=None or False. Otherwise :class:`AsyncResponse <AsyncResponse>`
|
||||
"""
|
||||
return await request( # type: ignore[misc]
|
||||
"DELETE",
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
auth=auth,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
proxies=proxies,
|
||||
verify=verify,
|
||||
stream=stream, # type: ignore[arg-type]
|
||||
cert=cert,
|
||||
hooks=hooks,
|
||||
retries=retries,
|
||||
**kwargs,
|
||||
)
|
||||
1623
.venv/lib/python3.9/site-packages/niquests/async_session.py
Normal file
1623
.venv/lib/python3.9/site-packages/niquests/async_session.py
Normal file
File diff suppressed because it is too large
Load Diff
429
.venv/lib/python3.9/site-packages/niquests/auth.py
Normal file
429
.venv/lib/python3.9/site-packages/niquests/auth.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
requests.auth
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
This module contains the authentication handlers for Requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
from dataclasses import dataclass, field
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._compat import iscoroutinefunction
|
||||
from .cookies import extract_cookies_to_jar
|
||||
from .utils import parse_dict_header
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .models import PreparedRequest
|
||||
|
||||
CONTENT_TYPE_FORM_URLENCODED: str = "application/x-www-form-urlencoded"
|
||||
CONTENT_TYPE_MULTI_PART: str = "multipart/form-data"
|
||||
|
||||
|
||||
def _basic_auth_str(username: str | bytes, password: str | bytes) -> str:
|
||||
"""Returns a Basic Auth string."""
|
||||
|
||||
if isinstance(username, str):
|
||||
username = username.encode("utf-8")
|
||||
|
||||
if isinstance(password, str):
|
||||
password = password.encode("utf-8")
|
||||
|
||||
authstr = "Basic " + b64encode(b":".join((username, password))).strip().decode()
|
||||
|
||||
return authstr
|
||||
|
||||
|
||||
class AsyncAuthBase:
|
||||
"""Base class that all asynchronous auth implementations derive from"""
|
||||
|
||||
async def __call__(self, r: PreparedRequest) -> PreparedRequest:
|
||||
raise NotImplementedError("Auth hooks must be callable.")
|
||||
|
||||
|
||||
class AuthBase:
|
||||
"""Base class that all synchronous auth implementations derive from"""
|
||||
|
||||
def __call__(self, r: PreparedRequest) -> PreparedRequest:
|
||||
raise NotImplementedError("Auth hooks must be callable.")
|
||||
|
||||
|
||||
class BearerTokenAuth(AuthBase):
|
||||
"""Simple token injection in Authorization header"""
|
||||
|
||||
def __init__(self, token: str):
|
||||
self.token = token
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return self.token == getattr(other, "token", None)
|
||||
|
||||
def __ne__(self, other) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __call__(self, r):
|
||||
detect_token_type: list[str] = self.token.split(" ", maxsplit=1)
|
||||
|
||||
if len(detect_token_type) == 1:
|
||||
r.headers["Authorization"] = f"Bearer {self.token}"
|
||||
else:
|
||||
r.headers["Authorization"] = self.token
|
||||
|
||||
return r
|
||||
|
||||
|
||||
class HTTPBasicAuth(AuthBase):
|
||||
"""Attaches HTTP Basic Authentication to the given Request object."""
|
||||
|
||||
def __init__(self, username: str | bytes, password: str | bytes):
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return all(
|
||||
[
|
||||
self.username == getattr(other, "username", None),
|
||||
self.password == getattr(other, "password", None),
|
||||
]
|
||||
)
|
||||
|
||||
def __ne__(self, other) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __call__(self, r):
|
||||
r.headers["Authorization"] = _basic_auth_str(self.username, self.password)
|
||||
return r
|
||||
|
||||
|
||||
class HTTPProxyAuth(HTTPBasicAuth):
|
||||
"""Attaches HTTP Proxy Authentication to a given Request object."""
|
||||
|
||||
def __call__(self, r):
|
||||
r.headers["Proxy-Authorization"] = _basic_auth_str(self.username, self.password)
|
||||
return r
|
||||
|
||||
|
||||
@dataclass
|
||||
class DigestAuthState:
|
||||
"""Container for digest auth state per task/thread"""
|
||||
|
||||
init: bool = False
|
||||
last_nonce: str = ""
|
||||
nonce_count: int = 0
|
||||
chal: typing.Mapping[str, str | None] = field(default_factory=dict)
|
||||
pos: int | None = None
|
||||
num_401_calls: int | None = None
|
||||
|
||||
|
||||
class HTTPDigestAuth(AuthBase):
|
||||
"""Attaches HTTP Digest Authentication to the given Request object."""
|
||||
|
||||
def __init__(self, username: str, password: str):
|
||||
self.username = username
|
||||
self.password = password
|
||||
# Keep state in per-thread local storage
|
||||
self._thread_local: contextvars.ContextVar[DigestAuthState] = contextvars.ContextVar("digest_auth_state")
|
||||
|
||||
def init_per_thread_state(self) -> None:
|
||||
# Ensure state is initialized just once per-thread
|
||||
state = self._thread_local.get(None)
|
||||
|
||||
if state is None or not state.init:
|
||||
self._thread_local.set(DigestAuthState(init=True))
|
||||
|
||||
def build_digest_header(self, method: str, url: str) -> str | None:
|
||||
state = self._thread_local.get(None)
|
||||
|
||||
assert state is not None
|
||||
|
||||
realm = state.chal["realm"]
|
||||
nonce = state.chal["nonce"]
|
||||
qop = state.chal.get("qop")
|
||||
algorithm = state.chal.get("algorithm")
|
||||
opaque = state.chal.get("opaque")
|
||||
|
||||
hash_utf8: typing.Callable[[str | bytes], str] | None = None
|
||||
|
||||
if algorithm is None:
|
||||
_algorithm = "MD5"
|
||||
else:
|
||||
_algorithm = algorithm.upper()
|
||||
# lambdas assume digest modules are imported at the top level
|
||||
if _algorithm == "MD5" or _algorithm == "MD5-SESS":
|
||||
|
||||
def md5_utf8(x: str | bytes) -> str:
|
||||
if isinstance(x, str):
|
||||
x = x.encode("utf-8")
|
||||
return hashlib.md5(x).hexdigest()
|
||||
|
||||
hash_utf8 = md5_utf8
|
||||
elif _algorithm == "SHA":
|
||||
|
||||
def sha_utf8(x: str | bytes) -> str:
|
||||
if isinstance(x, str):
|
||||
x = x.encode("utf-8")
|
||||
return hashlib.sha1(x).hexdigest()
|
||||
|
||||
hash_utf8 = sha_utf8
|
||||
elif _algorithm == "SHA-256":
|
||||
|
||||
def sha256_utf8(x: str | bytes) -> str:
|
||||
if isinstance(x, str):
|
||||
x = x.encode("utf-8")
|
||||
return hashlib.sha256(x).hexdigest()
|
||||
|
||||
hash_utf8 = sha256_utf8
|
||||
elif _algorithm == "SHA-512":
|
||||
|
||||
def sha512_utf8(x: str | bytes) -> str:
|
||||
if isinstance(x, str):
|
||||
x = x.encode("utf-8")
|
||||
return hashlib.sha512(x).hexdigest()
|
||||
|
||||
hash_utf8 = sha512_utf8
|
||||
else:
|
||||
raise ValueError(f"'{_algorithm}' hashing algorithm is not supported")
|
||||
|
||||
KD = lambda s, d: hash_utf8(f"{s}:{d}") # noqa:E731
|
||||
|
||||
if hash_utf8 is None:
|
||||
return None
|
||||
|
||||
# XXX not implemented yet
|
||||
entdig = None
|
||||
p_parsed = urlparse(url)
|
||||
#: path is request-uri defined in RFC 2616 which should not be empty
|
||||
path = p_parsed.path or "/"
|
||||
if p_parsed.query:
|
||||
path += f"?{p_parsed.query}"
|
||||
|
||||
A1 = f"{self.username}:{realm}:{self.password}"
|
||||
A2 = f"{method}:{path}"
|
||||
|
||||
HA1 = hash_utf8(A1)
|
||||
HA2 = hash_utf8(A2)
|
||||
|
||||
if nonce == state.last_nonce:
|
||||
state.nonce_count += 1
|
||||
else:
|
||||
state.nonce_count = 1
|
||||
ncvalue = f"{state.nonce_count:08x}"
|
||||
s = str(state.nonce_count).encode("utf-8")
|
||||
|
||||
assert nonce is not None
|
||||
|
||||
s += nonce.encode("utf-8")
|
||||
s += time.ctime().encode("utf-8")
|
||||
s += os.urandom(8)
|
||||
|
||||
cnonce = hashlib.sha1(s).hexdigest()[:16]
|
||||
if _algorithm == "MD5-SESS":
|
||||
HA1 = hash_utf8(f"{HA1}:{nonce}:{cnonce}")
|
||||
|
||||
if not qop:
|
||||
respdig = KD(HA1, f"{nonce}:{HA2}")
|
||||
elif qop == "auth" or "auth" in qop.split(","):
|
||||
noncebit = f"{nonce}:{ncvalue}:{cnonce}:auth:{HA2}"
|
||||
respdig = KD(HA1, noncebit)
|
||||
else:
|
||||
# XXX handle auth-int.
|
||||
return None
|
||||
|
||||
state.last_nonce = nonce
|
||||
|
||||
# XXX should the partial digests be encoded too?
|
||||
base = f'username="{self.username}", realm="{realm}", nonce="{nonce}", uri="{path}", response="{respdig}"'
|
||||
if opaque:
|
||||
base += f', opaque="{opaque}"'
|
||||
if algorithm:
|
||||
base += f', algorithm="{algorithm}"'
|
||||
if entdig:
|
||||
base += f', digest="{entdig}"'
|
||||
if qop:
|
||||
base += f', qop="auth", nc={ncvalue}, cnonce="{cnonce}"'
|
||||
|
||||
return f"Digest {base}"
|
||||
|
||||
def handle_redirect(self, r, **kwargs) -> None:
|
||||
"""Reset num_401_calls counter on redirects."""
|
||||
state = self._thread_local.get(None)
|
||||
|
||||
assert state is not None
|
||||
|
||||
if r.is_redirect:
|
||||
state.num_401_calls = 1
|
||||
|
||||
async def async_handle_401(self, r, **kwargs):
|
||||
"""
|
||||
Takes the given response and tries digest-auth, if needed (async version).
|
||||
|
||||
:rtype: requests.Response
|
||||
"""
|
||||
state = self._thread_local.get(None)
|
||||
|
||||
assert state is not None
|
||||
|
||||
# If response is not 4xx, do not auth
|
||||
# See https://github.com/psf/requests/issues/3772
|
||||
if not 400 <= r.status_code < 500:
|
||||
state.num_401_calls = 1
|
||||
return r
|
||||
|
||||
if state.pos is not None:
|
||||
# Rewind the file position indicator of the body to where
|
||||
# it was to resend the request.
|
||||
r.request.body.seek(state.pos)
|
||||
s_auth = r.headers.get("www-authenticate", "")
|
||||
|
||||
if "digest" in s_auth.lower() and state.num_401_calls < 2: # type: ignore[operator]
|
||||
state.num_401_calls += 1 # type: ignore[operator]
|
||||
pat = re.compile(r"digest ", flags=re.IGNORECASE)
|
||||
state.chal = parse_dict_header(pat.sub("", s_auth, count=1))
|
||||
|
||||
# Consume content and release the original connection
|
||||
# to allow our new request to reuse the same one.
|
||||
await r.content
|
||||
await r.close()
|
||||
prep = r.request.copy()
|
||||
extract_cookies_to_jar(prep._cookies, r.request, r.raw)
|
||||
prep.prepare_cookies(prep._cookies)
|
||||
|
||||
prep.headers["Authorization"] = self.build_digest_header(prep.method, prep.url)
|
||||
_r = await r.connection.send(prep, **kwargs)
|
||||
_r.history.append(r)
|
||||
_r.request = prep
|
||||
|
||||
return _r
|
||||
|
||||
state.num_401_calls = 1
|
||||
return r
|
||||
|
||||
def handle_401(self, r, **kwargs):
|
||||
"""
|
||||
Takes the given response and tries digest-auth, if needed.
|
||||
|
||||
:rtype: requests.Response
|
||||
"""
|
||||
state = self._thread_local.get(None)
|
||||
|
||||
assert state is not None
|
||||
|
||||
# If response is not 4xx, do not auth
|
||||
# See https://github.com/psf/requests/issues/3772
|
||||
if not 400 <= r.status_code < 500:
|
||||
state.num_401_calls = 1
|
||||
return r
|
||||
|
||||
if state.pos is not None:
|
||||
# Rewind the file position indicator of the body to where
|
||||
# it was to resend the request.
|
||||
r.request.body.seek(state.pos)
|
||||
s_auth = r.headers.get("www-authenticate", "")
|
||||
|
||||
if "digest" in s_auth.lower() and state.num_401_calls < 2: # type: ignore[operator]
|
||||
state.num_401_calls += 1 # type: ignore[operator]
|
||||
pat = re.compile(r"digest ", flags=re.IGNORECASE)
|
||||
state.chal = parse_dict_header(pat.sub("", s_auth, count=1))
|
||||
|
||||
# Consume content and release the original connection
|
||||
# to allow our new request to reuse the same one.
|
||||
r.content
|
||||
r.close()
|
||||
prep = r.request.copy()
|
||||
extract_cookies_to_jar(prep._cookies, r.request, r.raw)
|
||||
prep.prepare_cookies(prep._cookies)
|
||||
|
||||
prep.headers["Authorization"] = self.build_digest_header(prep.method, prep.url)
|
||||
_r = r.connection.send(prep, **kwargs)
|
||||
_r.history.append(r)
|
||||
_r.request = prep
|
||||
|
||||
return _r
|
||||
|
||||
state.num_401_calls = 1
|
||||
return r
|
||||
|
||||
def __call__(self, r):
|
||||
# Initialize per-thread state, if needed
|
||||
self.init_per_thread_state()
|
||||
state = self._thread_local.get(None)
|
||||
assert state is not None
|
||||
# If we have a saved nonce, skip the 401
|
||||
if state.last_nonce:
|
||||
r.headers["Authorization"] = self.build_digest_header(r.method, r.url)
|
||||
try:
|
||||
state.pos = r.body.tell()
|
||||
except AttributeError:
|
||||
# In the case of HTTPDigestAuth being reused and the body of
|
||||
# the previous request was a file-like object, pos has the
|
||||
# file position of the previous body. Ensure it's set to
|
||||
# None.
|
||||
state.pos = None
|
||||
# Register sync hooks only - use AsyncHTTPDigestAuth for async sessions
|
||||
r.register_hook("response", self.handle_401)
|
||||
r.register_hook("response", self.handle_redirect)
|
||||
state.num_401_calls = 1
|
||||
|
||||
return r
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
return all(
|
||||
[
|
||||
self.username == getattr(other, "username", None),
|
||||
self.password == getattr(other, "password", None),
|
||||
]
|
||||
)
|
||||
|
||||
def __ne__(self, other) -> bool:
|
||||
return not self == other
|
||||
|
||||
|
||||
class AsyncHTTPDigestAuth(HTTPDigestAuth, AsyncAuthBase):
|
||||
"""Async version of HTTPDigestAuth for use with AsyncSession.
|
||||
|
||||
Attaches HTTP Digest Authentication to the given Request object and handles
|
||||
401 responses asynchronously.
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> import niquests
|
||||
>>> auth = niquests.auth.AsyncHTTPDigestAuth('user', 'pass')
|
||||
>>> async with niquests.AsyncSession() as session:
|
||||
... r = await session.get('https://httpbin.org/digest-auth/auth/user/pass', auth=auth)
|
||||
... print(r.status_code)
|
||||
200
|
||||
"""
|
||||
|
||||
async def __call__(self, r):
|
||||
# Initialize per-thread state, if needed
|
||||
self.init_per_thread_state()
|
||||
state = self._thread_local.get(None)
|
||||
assert state is not None
|
||||
# If we have a saved nonce, skip the 401
|
||||
if state.last_nonce:
|
||||
r.headers["Authorization"] = self.build_digest_header(r.method, r.url)
|
||||
try:
|
||||
if iscoroutinefunction(r.body.tell):
|
||||
state.pos = await r.body.tell()
|
||||
else:
|
||||
state.pos = r.body.tell()
|
||||
|
||||
except AttributeError:
|
||||
# In the case of AsyncHTTPDigestAuth being reused and the body of
|
||||
# the previous request was a file-like object, pos has the
|
||||
# file position of the previous body. Ensure it's set to
|
||||
# None.
|
||||
state.pos = None
|
||||
# Register async hooks only
|
||||
r.register_hook("response", self.async_handle_401)
|
||||
r.register_hook("response", self.handle_redirect)
|
||||
state.num_401_calls = 1
|
||||
|
||||
return r
|
||||
592
.venv/lib/python3.9/site-packages/niquests/cookies.py
Normal file
592
.venv/lib/python3.9/site-packages/niquests/cookies.py
Normal file
@@ -0,0 +1,592 @@
|
||||
"""
|
||||
requests.cookies
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
Compatibility code to be able to use `http.cookiejar.CookieJar` with requests.
|
||||
|
||||
requests.utils imports from here, so be careful with imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import calendar
|
||||
import copy
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections.abc import MutableMapping
|
||||
from http import cookiejar as cookielib
|
||||
from http.cookiejar import CookieJar
|
||||
from http.cookies import Morsel
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from .packages.urllib3 import BaseHTTPResponse
|
||||
from .utils import parse_scheme
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .models import PreparedRequest, Request
|
||||
|
||||
|
||||
class CookiePolicyLocalhostBypass(cookielib.DefaultCookiePolicy):
|
||||
"""A subclass of DefaultCookiePolicy to allow cookie set for domain=localhost.
|
||||
Credit goes to https://github.com/Pylons/webtest/blob/main/webtest/app.py#L60"""
|
||||
|
||||
def return_ok_domain(self, cookie, request):
|
||||
if cookie.domain == ".localhost":
|
||||
return True
|
||||
return cookielib.DefaultCookiePolicy.return_ok_domain(self, cookie, request)
|
||||
|
||||
def set_ok_domain(self, cookie, request):
|
||||
if cookie.domain == ".localhost":
|
||||
return True
|
||||
return cookielib.DefaultCookiePolicy.set_ok_domain(self, cookie, request)
|
||||
|
||||
|
||||
class MockRequest:
|
||||
"""Wraps a `requests.Request` to mimic a `urllib2.Request`.
|
||||
|
||||
The code in `http.cookiejar.CookieJar` expects this interface in order to correctly
|
||||
manage cookie policies, i.e., determine whether a cookie can be set, given the
|
||||
domains of the request and the cookie.
|
||||
|
||||
The original request object is read-only. The client is responsible for collecting
|
||||
the new headers via `get_new_headers()` and interpreting them appropriately. You
|
||||
probably want `get_cookie_header`, defined below.
|
||||
"""
|
||||
|
||||
def __init__(self, request):
|
||||
self._r = request
|
||||
self._new_headers = {}
|
||||
|
||||
try:
|
||||
self.type: str | None = parse_scheme(self._r.url)
|
||||
except ValueError:
|
||||
self.type = None
|
||||
|
||||
def get_type(self):
|
||||
return self.type
|
||||
|
||||
def get_host(self):
|
||||
return urlparse(self._r.url).netloc
|
||||
|
||||
def get_origin_req_host(self):
|
||||
return self.get_host()
|
||||
|
||||
def get_full_url(self):
|
||||
# Only return the response's URL if the user hadn't set the Host
|
||||
# header
|
||||
if not self._r.headers.get("Host"):
|
||||
return self._r.url
|
||||
# If they did set it, retrieve it and reconstruct the expected domain
|
||||
host = self._r.headers["Host"]
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("utf-8")
|
||||
parsed = urlparse(self._r.url)
|
||||
# Reconstruct the URL as we expect it
|
||||
return urlunparse(
|
||||
[
|
||||
parsed.scheme,
|
||||
host,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
]
|
||||
)
|
||||
|
||||
def is_unverifiable(self):
|
||||
return True
|
||||
|
||||
def has_header(self, name):
|
||||
return name in self._r.headers or name in self._new_headers
|
||||
|
||||
def get_header(self, name, default=None):
|
||||
return self._r.headers.get(name, self._new_headers.get(name, default))
|
||||
|
||||
def add_header(self, key, val):
|
||||
"""cookiejar has no legitimate use for this method; add it back if you find one."""
|
||||
raise NotImplementedError("Cookie headers should be added with add_unredirected_header()")
|
||||
|
||||
def add_unredirected_header(self, name, value):
|
||||
self._new_headers[name] = value
|
||||
|
||||
def get_new_headers(self):
|
||||
return self._new_headers
|
||||
|
||||
@property
|
||||
def unverifiable(self):
|
||||
return self.is_unverifiable()
|
||||
|
||||
@property
|
||||
def origin_req_host(self):
|
||||
return self.get_origin_req_host()
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self.get_host()
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Wraps a `httplib.HTTPMessage` to mimic a `urllib.addinfourl`.
|
||||
|
||||
...what? Basically, expose the parsed HTTP headers from the server response
|
||||
the way `http.cookiejar` expects to see them.
|
||||
"""
|
||||
|
||||
def __init__(self, headers):
|
||||
"""Make a MockResponse for `cookiejar` to read.
|
||||
|
||||
:param headers: a httplib.HTTPMessage or analogous carrying the headers
|
||||
"""
|
||||
self._headers = headers
|
||||
|
||||
def info(self):
|
||||
return self._headers
|
||||
|
||||
def getheaders(self, name):
|
||||
self._headers.getheaders(name)
|
||||
|
||||
|
||||
def extract_cookies_to_jar(
|
||||
jar: CookieJar,
|
||||
request: Request | PreparedRequest | None,
|
||||
response: BaseHTTPResponse | None,
|
||||
):
|
||||
"""Extract the cookies from the response into a CookieJar.
|
||||
|
||||
:param jar: http.cookiejar.CookieJar (not necessarily a RequestsCookieJar)
|
||||
:param request: our own requests.Request object
|
||||
:param response: urllib3.HTTPResponse object
|
||||
"""
|
||||
if request is None or response is None:
|
||||
raise ValueError("Attempt to extract cookie from undefined request and/or response")
|
||||
|
||||
if not (hasattr(response, "_original_response") and response._original_response):
|
||||
return
|
||||
if "Set-Cookie" not in response._original_response.msg:
|
||||
return
|
||||
# the _original_response field is the wrapped httplib.HTTPResponse object,
|
||||
req = MockRequest(request)
|
||||
# pull out the HTTPMessage with the headers and put it in the mock:
|
||||
res = MockResponse(response._original_response.msg) # type: ignore[attr-defined]
|
||||
jar.extract_cookies(res, req) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def get_cookie_header(jar, request) -> str | None:
|
||||
"""
|
||||
Produce an appropriate Cookie header string to be sent with `request`, or None.
|
||||
"""
|
||||
r = MockRequest(request)
|
||||
jar.add_cookie_header(r)
|
||||
return r.get_new_headers().get("Cookie")
|
||||
|
||||
|
||||
def remove_cookie_by_name(cookiejar, name, domain=None, path=None):
|
||||
"""Unsets a cookie by name, by default over all domains and paths.
|
||||
|
||||
Wraps CookieJar.clear(), is O(n).
|
||||
"""
|
||||
clearables = []
|
||||
for cookie in cookiejar:
|
||||
if cookie.name != name:
|
||||
continue
|
||||
if domain is not None and domain != cookie.domain:
|
||||
continue
|
||||
if path is not None and path != cookie.path:
|
||||
continue
|
||||
clearables.append((cookie.domain, cookie.path, cookie.name))
|
||||
|
||||
for domain, path, name in clearables:
|
||||
cookiejar.clear(domain, path, name)
|
||||
|
||||
|
||||
class CookieConflictError(RuntimeError):
|
||||
"""There are two cookies that meet the criteria specified in the cookie jar.
|
||||
Use .get and .set and include domain and path args in order to be more specific.
|
||||
"""
|
||||
|
||||
|
||||
class RequestsCookieJar(cookielib.CookieJar, MutableMapping):
|
||||
"""Compatibility class; is a http.cookiejar.CookieJar, but exposes a dict
|
||||
interface.
|
||||
|
||||
This is the CookieJar we create by default for requests and sessions that
|
||||
don't specify one, since some clients may expect response.cookies and
|
||||
session.cookies to support dict operations.
|
||||
|
||||
Requests does not use the dict interface internally; it's just for
|
||||
compatibility with external client code. All requests code should work
|
||||
out of the box with externally provided instances of ``CookieJar``, e.g.
|
||||
``LWPCookieJar`` and ``FileCookieJar``.
|
||||
|
||||
Unlike a regular CookieJar, this class is pickleable.
|
||||
|
||||
.. warning:: dictionary operations that are normally O(1) may be O(n).
|
||||
"""
|
||||
|
||||
def __init__(self, policy: cookielib.CookiePolicy | None = None, thread_free: bool = False):
|
||||
super().__init__(policy=policy or CookiePolicyLocalhostBypass())
|
||||
if thread_free:
|
||||
from .structures import DummyLock
|
||||
|
||||
self._cookies_lock = DummyLock()
|
||||
|
||||
def get(self, name, default=None, domain=None, path=None):
|
||||
"""Dict-like get() that also supports optional domain and path args in
|
||||
order to resolve naming collisions from using one cookie jar over
|
||||
multiple domains.
|
||||
|
||||
.. warning:: operation is O(n), not O(1).
|
||||
"""
|
||||
try:
|
||||
return self._find_no_duplicates(name, domain, path)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def set(self, name, value, **kwargs):
|
||||
"""Dict-like set() that also supports optional domain and path args in
|
||||
order to resolve naming collisions from using one cookie jar over
|
||||
multiple domains.
|
||||
"""
|
||||
# support client code that unsets cookies by assignment of a None value:
|
||||
if value is None:
|
||||
remove_cookie_by_name(self, name, domain=kwargs.get("domain"), path=kwargs.get("path"))
|
||||
return
|
||||
|
||||
if isinstance(value, Morsel):
|
||||
c = morsel_to_cookie(value)
|
||||
else:
|
||||
c = create_cookie(name, value, **kwargs)
|
||||
self.set_cookie(c)
|
||||
return c
|
||||
|
||||
def iterkeys(self):
|
||||
"""Dict-like iterkeys() that returns an iterator of names of cookies
|
||||
from the jar.
|
||||
|
||||
.. seealso:: itervalues() and iteritems().
|
||||
"""
|
||||
for cookie in iter(self):
|
||||
yield cookie.name
|
||||
|
||||
def keys(self):
|
||||
"""Dict-like keys() that returns a list of names of cookies from the
|
||||
jar.
|
||||
|
||||
.. seealso:: values() and items().
|
||||
"""
|
||||
return list(self.iterkeys())
|
||||
|
||||
def itervalues(self):
|
||||
"""Dict-like itervalues() that returns an iterator of values of cookies
|
||||
from the jar.
|
||||
|
||||
.. seealso:: iterkeys() and iteritems().
|
||||
"""
|
||||
for cookie in iter(self):
|
||||
yield cookie.value
|
||||
|
||||
def values(self):
|
||||
"""Dict-like values() that returns a list of values of cookies from the
|
||||
jar.
|
||||
|
||||
.. seealso:: keys() and items().
|
||||
"""
|
||||
return list(self.itervalues())
|
||||
|
||||
def iteritems(self):
|
||||
"""Dict-like iteritems() that returns an iterator of name-value tuples
|
||||
from the jar.
|
||||
|
||||
.. seealso:: iterkeys() and itervalues().
|
||||
"""
|
||||
for cookie in iter(self):
|
||||
yield cookie.name, cookie.value
|
||||
|
||||
def items(self):
|
||||
"""Dict-like items() that returns a list of name-value tuples from the
|
||||
jar. Allows client-code to call ``dict(RequestsCookieJar)`` and get a
|
||||
vanilla python dict of key value pairs.
|
||||
|
||||
.. seealso:: keys() and values().
|
||||
"""
|
||||
# todo: comply and return ItemView!
|
||||
return list(self.iteritems())
|
||||
|
||||
def list_domains(self) -> list[str]:
|
||||
"""Utility method to list all the domains in the jar."""
|
||||
domains = []
|
||||
for cookie in iter(self):
|
||||
if cookie.domain not in domains:
|
||||
domains.append(cookie.domain)
|
||||
return domains
|
||||
|
||||
def list_paths(self) -> list[str]:
|
||||
"""Utility method to list all the paths in the jar."""
|
||||
paths = []
|
||||
for cookie in iter(self):
|
||||
if cookie.path not in paths:
|
||||
paths.append(cookie.path)
|
||||
return paths
|
||||
|
||||
def multiple_domains(self) -> bool:
|
||||
"""Returns True if there are multiple domains in the jar.
|
||||
Returns False otherwise.
|
||||
"""
|
||||
domains = []
|
||||
for cookie in iter(self):
|
||||
if cookie.domain is not None and cookie.domain in domains:
|
||||
return True
|
||||
domains.append(cookie.domain)
|
||||
return False # there is only one domain in jar
|
||||
|
||||
def get_dict(self, domain: str | None = None, path: str | None = None) -> dict[str, str | None]:
|
||||
"""Takes as an argument an optional domain and path and returns a plain
|
||||
old Python dict of name-value pairs of cookies that meet the
|
||||
requirements.
|
||||
"""
|
||||
dictionary = {}
|
||||
for cookie in iter(self):
|
||||
if (domain is None or cookie.domain == domain) and (path is None or cookie.path == path):
|
||||
dictionary[cookie.name] = cookie.value
|
||||
return dictionary
|
||||
|
||||
def __contains__(self, name) -> bool:
|
||||
try:
|
||||
return super().__contains__(name)
|
||||
except CookieConflictError:
|
||||
return True
|
||||
|
||||
def __getitem__(self, name):
|
||||
"""Dict-like __getitem__() for compatibility with client code. Throws
|
||||
exception if there are more than one cookie with name. In that case,
|
||||
use the more explicit get() method instead.
|
||||
|
||||
.. warning:: operation is O(n), not O(1).
|
||||
"""
|
||||
return self._find_no_duplicates(name)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
"""Dict-like __setitem__ for compatibility with client code. Throws
|
||||
exception if there is already a cookie of that name in the jar. In that
|
||||
case, use the more explicit set() method instead.
|
||||
"""
|
||||
self.set(name, value)
|
||||
|
||||
def __delitem__(self, name):
|
||||
"""Deletes a cookie given a name. Wraps ``http.cookiejar.CookieJar``'s
|
||||
``remove_cookie_by_name()``.
|
||||
"""
|
||||
remove_cookie_by_name(self, name)
|
||||
|
||||
def set_cookie(self, cookie, *args, **kwargs):
|
||||
if hasattr(cookie.value, "startswith") and cookie.value.startswith('"') and cookie.value.endswith('"'):
|
||||
cookie.value = cookie.value.replace('\\"', "")
|
||||
return super().set_cookie(cookie, *args, **kwargs)
|
||||
|
||||
def update(self, other): # type: ignore[override]
|
||||
"""Updates this jar with cookies from another CookieJar or dict-like"""
|
||||
if isinstance(other, cookielib.CookieJar):
|
||||
for cookie in other:
|
||||
self.set_cookie(copy.copy(cookie))
|
||||
else:
|
||||
super().update(other)
|
||||
|
||||
def _find(self, name, domain=None, path=None):
|
||||
"""Requests uses this method internally to get cookie values.
|
||||
|
||||
If there are conflicting cookies, _find arbitrarily chooses one.
|
||||
See _find_no_duplicates if you want an exception thrown if there are
|
||||
conflicting cookies.
|
||||
|
||||
:param name: a string containing name of cookie
|
||||
:param domain: (optional) string containing domain of cookie
|
||||
:param path: (optional) string containing path of cookie
|
||||
:return: cookie.value
|
||||
"""
|
||||
for cookie in iter(self):
|
||||
if cookie.name == name:
|
||||
if domain is None or cookie.domain == domain:
|
||||
if path is None or cookie.path == path:
|
||||
return cookie.value
|
||||
|
||||
raise KeyError(f"name={name!r}, domain={domain!r}, path={path!r}")
|
||||
|
||||
def _find_no_duplicates(self, name, domain=None, path=None):
|
||||
"""Both ``__get_item__`` and ``get`` call this function: it's never
|
||||
used elsewhere in Requests.
|
||||
|
||||
:param name: a string containing name of cookie
|
||||
:param domain: (optional) string containing domain of cookie
|
||||
:param path: (optional) string containing path of cookie
|
||||
:raises KeyError: if cookie is not found
|
||||
:raises CookieConflictError: if there are multiple cookies
|
||||
that match name and optionally domain and path
|
||||
:return: cookie.value
|
||||
"""
|
||||
toReturn = None
|
||||
for cookie in iter(self):
|
||||
if cookie.name == name:
|
||||
if domain is None or cookie.domain == domain:
|
||||
if path is None or cookie.path == path:
|
||||
if toReturn is not None:
|
||||
# if there are multiple cookies that meet passed in criteria
|
||||
raise CookieConflictError(f"There are multiple cookies with name, {name!r}")
|
||||
# we will eventually return this as long as no cookie conflict
|
||||
toReturn = cookie.value
|
||||
|
||||
if toReturn:
|
||||
return toReturn
|
||||
raise KeyError(f"name={name!r}, domain={domain!r}, path={path!r}")
|
||||
|
||||
def __getstate__(self):
|
||||
"""Unlike a normal CookieJar, this class is pickleable."""
|
||||
state = self.__dict__.copy()
|
||||
# remove the unpickleable RLock object
|
||||
state.pop("_cookies_lock")
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Unlike a normal CookieJar, this class is pickleable."""
|
||||
self.__dict__.update(state)
|
||||
if "_cookies_lock" not in self.__dict__:
|
||||
self._cookies_lock = threading.RLock() # type: ignore[assignment]
|
||||
|
||||
def copy(self):
|
||||
"""Return a copy of this RequestsCookieJar."""
|
||||
new_cj = RequestsCookieJar()
|
||||
new_cj.set_policy(self.get_policy())
|
||||
new_cj.update(self)
|
||||
return new_cj
|
||||
|
||||
def get_policy(self):
|
||||
"""Return the CookiePolicy instance used."""
|
||||
return self._policy # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _copy_cookie_jar(jar):
|
||||
if jar is None:
|
||||
return None
|
||||
|
||||
if hasattr(jar, "copy"):
|
||||
# We're dealing with an instance of RequestsCookieJar
|
||||
return jar.copy()
|
||||
# We're dealing with a generic CookieJar instance
|
||||
new_jar = copy.copy(jar)
|
||||
new_jar.clear()
|
||||
for cookie in jar:
|
||||
new_jar.set_cookie(copy.copy(cookie))
|
||||
return new_jar
|
||||
|
||||
|
||||
def create_cookie(name, value, **kwargs):
|
||||
"""Make a cookie from underspecified parameters.
|
||||
|
||||
By default, the pair of `name` and `value` will be set for the domain ''
|
||||
and sent on every request (this is sometimes called a "supercookie").
|
||||
"""
|
||||
result = {
|
||||
"version": 0,
|
||||
"name": name,
|
||||
"value": value,
|
||||
"port": None,
|
||||
"domain": "",
|
||||
"path": "/",
|
||||
"secure": False,
|
||||
"expires": None,
|
||||
"discard": True,
|
||||
"comment": None,
|
||||
"comment_url": None,
|
||||
"rest": {"HttpOnly": None},
|
||||
"rfc2109": False,
|
||||
}
|
||||
|
||||
badargs = set(kwargs) - set(result)
|
||||
if badargs:
|
||||
raise TypeError(f"create_cookie() got unexpected keyword arguments: {list(badargs)}")
|
||||
|
||||
result.update(kwargs)
|
||||
result["port_specified"] = bool(result["port"])
|
||||
result["domain_specified"] = bool(result["domain"])
|
||||
result["domain_initial_dot"] = result["domain"].startswith(".")
|
||||
result["path_specified"] = bool(result["path"])
|
||||
|
||||
return cookielib.Cookie(**result)
|
||||
|
||||
|
||||
def morsel_to_cookie(morsel):
|
||||
"""Convert a Morsel object into a Cookie containing the one k/v pair."""
|
||||
|
||||
expires = None
|
||||
if morsel["max-age"]:
|
||||
try:
|
||||
expires = int(time.time() + int(morsel["max-age"]))
|
||||
except ValueError:
|
||||
raise TypeError(f"max-age: {morsel['max-age']} must be integer")
|
||||
elif morsel["expires"]:
|
||||
time_template = "%a, %d-%b-%Y %H:%M:%S GMT"
|
||||
expires = calendar.timegm(time.strptime(morsel["expires"], time_template))
|
||||
return create_cookie(
|
||||
comment=morsel["comment"],
|
||||
comment_url=bool(morsel["comment"]),
|
||||
discard=False,
|
||||
domain=morsel["domain"],
|
||||
expires=expires,
|
||||
name=morsel.key,
|
||||
path=morsel["path"],
|
||||
port=None,
|
||||
rest={"HttpOnly": morsel["httponly"]},
|
||||
rfc2109=False,
|
||||
secure=bool(morsel["secure"]),
|
||||
value=morsel.value,
|
||||
version=morsel["version"] or 0,
|
||||
)
|
||||
|
||||
|
||||
def cookiejar_from_dict(
|
||||
cookie_dict: typing.MutableMapping[str, str] | None,
|
||||
cookiejar: RequestsCookieJar | cookielib.CookieJar | None = None,
|
||||
overwrite: bool = True,
|
||||
thread_free: bool = False,
|
||||
) -> RequestsCookieJar | cookielib.CookieJar:
|
||||
"""Returns a CookieJar from a key/value dictionary.
|
||||
|
||||
:param cookie_dict: Dict of key/values to insert into CookieJar.
|
||||
:param cookiejar: (optional) A cookiejar to add the cookies to.
|
||||
:param overwrite: (optional) If False, will not replace cookies
|
||||
already in the jar with new ones.
|
||||
"""
|
||||
if cookiejar is None:
|
||||
cookiejar = RequestsCookieJar(thread_free=thread_free)
|
||||
|
||||
if cookie_dict is not None:
|
||||
names_from_jar = [cookie.name for cookie in cookiejar]
|
||||
for name in cookie_dict:
|
||||
if overwrite or (name not in names_from_jar):
|
||||
cookiejar.set_cookie(create_cookie(name, cookie_dict[name]))
|
||||
|
||||
return cookiejar
|
||||
|
||||
|
||||
def merge_cookies(
|
||||
cookiejar: RequestsCookieJar | cookielib.CookieJar,
|
||||
cookies: typing.Mapping[str, str] | RequestsCookieJar | CookieJar,
|
||||
) -> RequestsCookieJar | cookielib.CookieJar:
|
||||
"""Add cookies to cookiejar and returns a merged CookieJar.
|
||||
|
||||
:param cookiejar: CookieJar object to add the cookies to.
|
||||
:param cookies: Dictionary or CookieJar object to be added.
|
||||
"""
|
||||
if not isinstance(cookiejar, cookielib.CookieJar):
|
||||
raise ValueError("You can only merge into CookieJar")
|
||||
|
||||
if isinstance(cookies, dict):
|
||||
cookiejar = cookiejar_from_dict(cookies, cookiejar=cookiejar, overwrite=False)
|
||||
elif isinstance(cookies, cookielib.CookieJar):
|
||||
if isinstance(cookiejar, RequestsCookieJar):
|
||||
cookiejar.update(cookies)
|
||||
else:
|
||||
for cookie_in_jar in cookies:
|
||||
cookiejar.set_cookie(cookie_in_jar)
|
||||
|
||||
return cookiejar
|
||||
155
.venv/lib/python3.9/site-packages/niquests/exceptions.py
Normal file
155
.venv/lib/python3.9/site-packages/niquests/exceptions.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
requests.exceptions
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
This module contains the set of Requests' exceptions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from json import JSONDecodeError as CompatJSONDecodeError
|
||||
|
||||
from .packages.urllib3.exceptions import HTTPError as BaseHTTPError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .models import PreparedRequest, Response
|
||||
|
||||
|
||||
class RequestException(IOError):
|
||||
"""There was an ambiguous exception that occurred while handling your
|
||||
request.
|
||||
"""
|
||||
|
||||
response: Response | None
|
||||
request: PreparedRequest | None
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""Initialize RequestException with `request` and `response` objects."""
|
||||
response = kwargs.pop("response", None)
|
||||
self.response = response
|
||||
self.request = kwargs.pop("request", None)
|
||||
if self.response is not None and not self.request and hasattr(self.response, "request"):
|
||||
self.request = self.response.request
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class InvalidJSONError(RequestException):
|
||||
"""A JSON error occurred."""
|
||||
|
||||
|
||||
class JSONDecodeError(InvalidJSONError, CompatJSONDecodeError):
|
||||
"""Couldn't decode the text into json"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Construct the JSONDecodeError instance first with all
|
||||
args. Then use it's args to construct the IOError so that
|
||||
the json specific args aren't used as IOError specific args
|
||||
and the error message from JSONDecodeError is preserved.
|
||||
"""
|
||||
CompatJSONDecodeError.__init__(self, *args)
|
||||
InvalidJSONError.__init__(self, *self.args, **kwargs)
|
||||
|
||||
|
||||
class HTTPError(RequestException):
|
||||
"""An HTTP error occurred."""
|
||||
|
||||
|
||||
class ConnectionError(RequestException):
|
||||
"""A Connection error occurred."""
|
||||
|
||||
|
||||
class ProxyError(ConnectionError):
|
||||
"""A proxy error occurred."""
|
||||
|
||||
|
||||
class SSLError(ConnectionError):
|
||||
"""An SSL error occurred."""
|
||||
|
||||
|
||||
class Timeout(RequestException):
|
||||
"""The request timed out.
|
||||
|
||||
Catching this error will catch both
|
||||
:exc:`~requests.exceptions.ConnectTimeout` and
|
||||
:exc:`~requests.exceptions.ReadTimeout` errors.
|
||||
"""
|
||||
|
||||
|
||||
class ConnectTimeout(ConnectionError, Timeout):
|
||||
"""The request timed out while trying to connect to the remote server.
|
||||
|
||||
Requests that produced this error are safe to retry.
|
||||
"""
|
||||
|
||||
|
||||
class ReadTimeout(Timeout):
|
||||
"""The server did not send any data in the allotted amount of time."""
|
||||
|
||||
|
||||
class URLRequired(RequestException):
|
||||
"""A valid URL is required to make a request."""
|
||||
|
||||
|
||||
class TooManyRedirects(RequestException):
|
||||
"""Too many redirects."""
|
||||
|
||||
|
||||
class MissingSchema(RequestException, ValueError):
|
||||
"""The URL scheme (e.g. http or https) is missing."""
|
||||
|
||||
|
||||
class InvalidSchema(RequestException, ValueError):
|
||||
"""The URL scheme provided is either invalid or unsupported."""
|
||||
|
||||
|
||||
class InvalidURL(RequestException, ValueError):
|
||||
"""The URL provided was somehow invalid."""
|
||||
|
||||
|
||||
class InvalidHeader(RequestException, ValueError):
|
||||
"""The header value provided was somehow invalid."""
|
||||
|
||||
|
||||
class InvalidProxyURL(InvalidURL):
|
||||
"""The proxy URL provided is invalid."""
|
||||
|
||||
|
||||
class ChunkedEncodingError(RequestException):
|
||||
"""The server declared chunked encoding but sent an invalid chunk."""
|
||||
|
||||
|
||||
class ContentDecodingError(RequestException, BaseHTTPError):
|
||||
"""Failed to decode response content."""
|
||||
|
||||
|
||||
class StreamConsumedError(RequestException, TypeError):
|
||||
"""The content for this response was already consumed."""
|
||||
|
||||
|
||||
class RetryError(RequestException):
|
||||
"""Custom retries logic failed"""
|
||||
|
||||
|
||||
class UnrewindableBodyError(RequestException):
|
||||
"""Requests encountered an error when trying to rewind a body."""
|
||||
|
||||
|
||||
class MultiplexingError(RequestException):
|
||||
"""Requests encountered an unresolvable error in multiplexed mode."""
|
||||
|
||||
|
||||
# Warnings
|
||||
|
||||
|
||||
class RequestsWarning(Warning):
|
||||
"""Base warning for Requests."""
|
||||
|
||||
|
||||
class FileModeWarning(RequestsWarning, DeprecationWarning):
|
||||
"""A file was opened in text mode, but Requests determined its binary length."""
|
||||
|
||||
|
||||
class RequestsDependencyWarning(RequestsWarning):
|
||||
"""An imported dependency doesn't match the expected version range."""
|
||||
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
This subpackage hold anything that is very relevant
|
||||
to the HTTP ecosystem but not per-say Niquests core logic.
|
||||
"""
|
||||
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
from pyodide.ffi import run_sync # type: ignore[import]
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ..._constant import DEFAULT_RETRIES
|
||||
from ...adapters import BaseAdapter
|
||||
from ...exceptions import ConnectionError, ConnectTimeout
|
||||
from ...models import PreparedRequest, Response
|
||||
from ...packages.urllib3.exceptions import MaxRetryError
|
||||
from ...packages.urllib3.response import BytesQueueBuffer
|
||||
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
|
||||
from ...packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ...packages.urllib3.util.retry import Retry
|
||||
from ...structures import CaseInsensitiveDict
|
||||
from ...utils import get_encoding_from_headers
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _PyodideRawIO:
|
||||
"""File-like wrapper around a Pyodide Fetch response with true streaming via JSPI.
|
||||
|
||||
When constructed with a JS Response object, reads chunks incrementally from the
|
||||
JavaScript ReadableStream using ``run_sync(reader.read())`` per chunk.
|
||||
When constructed with preloaded content (non-streaming), serves from a memory buffer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
js_response: typing.Any = None,
|
||||
preloaded_content: bytes | None = None,
|
||||
) -> None:
|
||||
self._js_response = js_response
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = preloaded_content is not None or js_response is None
|
||||
self._reader: typing.Any = None
|
||||
self.headers: dict[str, str] = {}
|
||||
self.extension: typing.Any = None
|
||||
|
||||
if preloaded_content is not None:
|
||||
self._buffer.put(preloaded_content)
|
||||
|
||||
def _ensure_reader(self) -> None:
|
||||
"""Initialize the ReadableStream reader if not already done."""
|
||||
if self._reader is None and self._js_response is not None:
|
||||
try:
|
||||
body = self._js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_next_chunk(self) -> bytes | None:
|
||||
"""Read the next chunk from the JS ReadableStream, blocking via JSPI."""
|
||||
self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = run_sync(self._reader.read())
|
||||
|
||||
if result.done:
|
||||
return None
|
||||
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py())
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def read(
|
||||
self,
|
||||
amt: int | None = None,
|
||||
decode_content: bool = True,
|
||||
) -> bytes:
|
||||
if self._closed:
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
if self._finished:
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
if amt is None or amt < 0:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read everything remaining
|
||||
while True:
|
||||
chunk = self._get_next_chunk()
|
||||
if chunk is None:
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
# Read until we have enough bytes or stream ends
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = self._get_next_chunk()
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
|
||||
"""Iterate over chunks of the response."""
|
||||
while True:
|
||||
chunk = self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
if self._reader is not None:
|
||||
try:
|
||||
run_sync(self._reader.cancel())
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._js_response = None
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
chunk = self.read(8192)
|
||||
if not chunk:
|
||||
raise StopIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class PyodideAdapter(BaseAdapter):
|
||||
"""Synchronous adapter for making HTTP requests in Pyodide using JSPI + pyfetch."""
|
||||
|
||||
def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
super().__init__()
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<PyodideAdapter WASM/>"
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest using Pyodide's pyfetch (synchronous via JSPI)."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
start = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
|
||||
retries.sleep()
|
||||
continue
|
||||
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
retries.sleep(base_response)
|
||||
continue
|
||||
|
||||
response.elapsed = timedelta(seconds=time.time() - start)
|
||||
return response
|
||||
|
||||
def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> Response:
|
||||
"""Perform the actual request using pyfetch made synchronous via JSPI."""
|
||||
url = request.url or ""
|
||||
scheme = url.split("://")[0].lower() if "://" in url else ""
|
||||
|
||||
# WebSocket: delegate to browser native WebSocket API
|
||||
if scheme in ("ws", "wss"):
|
||||
return self._do_send_ws(request, url)
|
||||
|
||||
# SSE: delegate to pyfetch streaming + manual SSE parsing
|
||||
if scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request, url, scheme)
|
||||
|
||||
# Prepare headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
|
||||
headers_dict[key] = value
|
||||
|
||||
# Prepare body
|
||||
body = request.body
|
||||
if body is not None:
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
|
||||
chunks: list[bytes] = []
|
||||
for chunk in body:
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
elif isinstance(chunk, bytes):
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
|
||||
# Build fetch options
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": request.method or "GET",
|
||||
"headers": headers_dict,
|
||||
}
|
||||
|
||||
if body:
|
||||
fetch_options["body"] = body
|
||||
|
||||
# Use AbortSignal.timeout() for timeout — the browser-native mechanism.
|
||||
# run_sync cannot interrupt a JS Promise.
|
||||
signal = None
|
||||
|
||||
if timeout is not None:
|
||||
from js import AbortSignal # type: ignore[import]
|
||||
|
||||
signal = AbortSignal.timeout(int(timeout * 1000))
|
||||
|
||||
try:
|
||||
js_response = run_sync(pyfetch(request.url, signal=signal, **fetch_options))
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
|
||||
raise ConnectTimeout(f"Connection to {request.url} timed out")
|
||||
raise ConnectionError(f"Failed to fetch {request.url}: {e}")
|
||||
|
||||
# Parse response headers
|
||||
response_headers: dict[str, str] = {}
|
||||
try:
|
||||
if hasattr(js_response, "headers"):
|
||||
js_headers = js_response.headers
|
||||
if hasattr(js_headers, "items"):
|
||||
for key, value in js_headers.items():
|
||||
response_headers[key] = value
|
||||
elif hasattr(js_headers, "entries"):
|
||||
for entry in js_headers.entries():
|
||||
response_headers[entry[0]] = entry[1]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build response object
|
||||
response = Response()
|
||||
response.status_code = js_response.status
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = js_response.url or request.url
|
||||
response.encoding = get_encoding_from_headers(response_headers)
|
||||
|
||||
try:
|
||||
response.reason = js_response.status_text or ""
|
||||
except Exception:
|
||||
response.reason = ""
|
||||
|
||||
if stream:
|
||||
# Streaming: pass the underlying JS Response to the raw IO
|
||||
# so it can read chunks incrementally via run_sync(reader.read())
|
||||
raw_io = _PyodideRawIO(js_response=js_response.js_response)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# Non-streaming: read full body upfront
|
||||
try:
|
||||
response_body: bytes = run_sync(js_response.bytes())
|
||||
except Exception:
|
||||
response_body = b""
|
||||
|
||||
raw_io = _PyodideRawIO(preloaded_content=response_body)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = response_body
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_ws(self, request: PreparedRequest, url: str) -> Response:
|
||||
"""Handle WebSocket connections via browser native WebSocket API."""
|
||||
from ._ws import PyodideWebSocketExtension
|
||||
|
||||
try:
|
||||
ext = PyodideWebSocketExtension(url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "Switching Protocols"
|
||||
|
||||
raw_io = _PyodideRawIO()
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = b""
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> Response:
|
||||
"""Handle SSE connections via pyfetch streaming + manual parsing."""
|
||||
from ._sse import PyodideSSEExtension
|
||||
|
||||
http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
|
||||
|
||||
# Pass through user-provided headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection"):
|
||||
headers_dict[key] = value
|
||||
|
||||
try:
|
||||
ext = PyodideSSEExtension(http_url, headers=headers_dict)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"SSE connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "OK"
|
||||
|
||||
raw_io = _PyodideRawIO()
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
|
||||
return response
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ("PyodideAdapter",)
|
||||
@@ -0,0 +1,451 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import typing
|
||||
from datetime import timedelta
|
||||
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ...._constant import DEFAULT_RETRIES
|
||||
from ....adapters import AsyncBaseAdapter
|
||||
from ....exceptions import ConnectionError, ConnectTimeout, ReadTimeout
|
||||
from ....models import AsyncResponse, PreparedRequest, Response
|
||||
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
|
||||
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
|
||||
from ....packages.urllib3.exceptions import MaxRetryError
|
||||
from ....packages.urllib3.response import BytesQueueBuffer
|
||||
from ....packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ....packages.urllib3.util.retry import Retry
|
||||
from ....structures import CaseInsensitiveDict
|
||||
from ....utils import _swap_context, get_encoding_from_headers
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ....typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _AsyncPyodideRawIO:
|
||||
"""
|
||||
Async file-like wrapper around Pyodide Fetch response for true streaming.
|
||||
|
||||
This class uses the JavaScript ReadableStream API through Pyodide to provide
|
||||
genuine streaming support without buffering the entire response in memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
js_response: typing.Any, # JavaScript Response object
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
self._js_response = js_response
|
||||
self._timeout = timeout
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = False
|
||||
self._reader: typing.Any = None # JavaScript ReadableStreamDefaultReader
|
||||
self.headers: dict[str, str] = {}
|
||||
self.extension: typing.Any = None
|
||||
|
||||
async def _ensure_reader(self) -> None:
|
||||
"""Initialize the stream reader if not already done."""
|
||||
if self._reader is None and self._js_response is not None:
|
||||
try:
|
||||
body = self._js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
|
||||
"""
|
||||
Read up to `amt` bytes from the response stream.
|
||||
|
||||
When `amt` is None, reads the entire remaining response.
|
||||
"""
|
||||
if self._closed:
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
if self._finished:
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
if amt is None or amt < 0:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read everything remaining
|
||||
async for chunk in self._stream_chunks():
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer)) if len(self._buffer) > 0 else b""
|
||||
|
||||
# Read until we have enough bytes or stream ends
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = await self._get_next_chunk() # type: ignore[assignment]
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
async def _get_next_chunk(self) -> bytes | None:
|
||||
"""Read the next chunk from the JavaScript ReadableStream."""
|
||||
await self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if self._timeout is not None:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
result = await self._reader.read()
|
||||
else:
|
||||
result = await self._reader.read()
|
||||
|
||||
if result.done:
|
||||
return None
|
||||
|
||||
value = result.value
|
||||
if value is not None:
|
||||
# Convert Uint8Array to bytes
|
||||
return bytes(value.to_py())
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout("Read timed out while streaming Pyodide response")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _stream_chunks(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Async generator that yields chunks from the stream."""
|
||||
await self._ensure_reader()
|
||||
|
||||
if self._reader is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
chunk = await self._get_next_chunk()
|
||||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Return an async generator that yields chunks of `amt` bytes."""
|
||||
return self._async_stream(amt)
|
||||
|
||||
async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes, None]:
|
||||
"""Internal async generator for streaming."""
|
||||
while True:
|
||||
chunk = await self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
self._closed = True
|
||||
if self._reader is not None:
|
||||
try:
|
||||
await self._reader.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._js_response = None
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
chunk = await self.read(8192)
|
||||
if not chunk:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class AsyncPyodideAdapter(AsyncBaseAdapter):
|
||||
"""Async adapter for making HTTP requests in Pyodide using the native pyfetch API."""
|
||||
|
||||
def __init__(self, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
"""
|
||||
Initialize the async Pyodide adapter.
|
||||
|
||||
:param max_retries: Maximum number of retries for requests.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<AsyncPyodideAdapter WASM/>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
|
||||
on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> AsyncResponse:
|
||||
"""Send a PreparedRequest using Pyodide's pyfetch (JavaScript Fetch API)."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
start = time.time()
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
await retries.async_sleep()
|
||||
continue
|
||||
|
||||
# We rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
await retries.async_sleep(base_response)
|
||||
continue
|
||||
|
||||
response.elapsed = timedelta(seconds=time.time() - start)
|
||||
return response
|
||||
|
||||
async def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Perform the actual request using Pyodide's pyfetch."""
|
||||
url = request.url or ""
|
||||
scheme = url.split("://")[0].lower() if "://" in url else ""
|
||||
|
||||
# WebSocket: delegate to browser native WebSocket API
|
||||
if scheme in ("ws", "wss"):
|
||||
return await self._do_send_ws(request, url)
|
||||
|
||||
# SSE: delegate to pyfetch streaming + manual SSE parsing
|
||||
if scheme in ("sse", "psse"):
|
||||
return await self._do_send_sse(request, url, scheme)
|
||||
|
||||
# Prepare headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
# Skip headers that browsers don't allow to be set
|
||||
if key.lower() not in ("host", "content-length", "connection", "transfer-encoding"):
|
||||
headers_dict[key] = value
|
||||
|
||||
# Prepare body
|
||||
body = request.body
|
||||
if body is not None:
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif hasattr(body, "__aiter__"):
|
||||
# Consume async iterable body
|
||||
chunks: list[bytes] = []
|
||||
async for chunk in body: # type: ignore[union-attr]
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
else:
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (bytes, bytearray)):
|
||||
# Consume sync iterable body
|
||||
chunks = []
|
||||
for chunk in body:
|
||||
if isinstance(chunk, str):
|
||||
chunks.append(chunk.encode("utf-8"))
|
||||
elif isinstance(chunk, bytes):
|
||||
chunks.append(chunk)
|
||||
body = b"".join(chunks)
|
||||
|
||||
# Build fetch options
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": request.method or "GET",
|
||||
"headers": headers_dict,
|
||||
}
|
||||
|
||||
if body:
|
||||
fetch_options["body"] = body
|
||||
|
||||
# Use AbortSignal.timeout() for timeout — the browser-native mechanism.
|
||||
# asyncio_timeout cannot interrupt a single JS Promise await.
|
||||
signal = None
|
||||
|
||||
if timeout is not None:
|
||||
from js import AbortSignal # type: ignore[import]
|
||||
|
||||
signal = AbortSignal.timeout(int(timeout * 1000))
|
||||
|
||||
try:
|
||||
js_response = await pyfetch(request.url, signal=signal, **fetch_options)
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
if "abort" in err_str or "timeout" in err_str or "timed out" in err_str:
|
||||
raise ConnectTimeout(f"Connection to {request.url} timed out")
|
||||
raise ConnectionError(f"Failed to fetch {request.url}: {e}")
|
||||
|
||||
# Parse response headers
|
||||
response_headers: dict[str, str] = {}
|
||||
try:
|
||||
# Pyodide's FetchResponse has headers as a dict-like object
|
||||
if hasattr(js_response, "headers"):
|
||||
js_headers = js_response.headers
|
||||
if hasattr(js_headers, "items"):
|
||||
for key, value in js_headers.items():
|
||||
response_headers[key] = value
|
||||
elif hasattr(js_headers, "entries"):
|
||||
# JavaScript Headers.entries() returns an iterator
|
||||
for entry in js_headers.entries():
|
||||
response_headers[entry[0]] = entry[1]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build response object
|
||||
response = Response()
|
||||
response.status_code = js_response.status
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = js_response.url or request.url
|
||||
response.encoding = get_encoding_from_headers(response_headers)
|
||||
|
||||
# Try to get status text
|
||||
try:
|
||||
response.reason = js_response.status_text or ""
|
||||
except Exception:
|
||||
response.reason = ""
|
||||
|
||||
if stream:
|
||||
# For streaming: set up async raw IO using the JS Response object
|
||||
# This provides true streaming without buffering entire response
|
||||
raw_io = _AsyncPyodideRawIO(js_response.js_response, timeout)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# For non-streaming: get full response body
|
||||
try:
|
||||
if timeout is not None:
|
||||
async with asyncio_timeout(timeout):
|
||||
response_body = await js_response.bytes()
|
||||
else:
|
||||
response_body = await js_response.bytes()
|
||||
except asyncio.TimeoutError:
|
||||
raise ReadTimeout(f"Read timed out for {request.url}")
|
||||
|
||||
response._content = response_body
|
||||
raw_io = _AsyncPyodideRawIO(None, timeout)
|
||||
raw_io.headers = response_headers
|
||||
response.raw = raw_io # type: ignore
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def _do_send_ws(self, request: PreparedRequest, url: str) -> AsyncResponse:
|
||||
"""Handle WebSocket connections via browser native WebSocket API."""
|
||||
from ._ws import AsyncPyodideWebSocketExtension
|
||||
|
||||
ext = AsyncPyodideWebSocketExtension()
|
||||
|
||||
try:
|
||||
await ext.start(url)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"WebSocket connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket", "connection": "upgrade"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "Switching Protocols"
|
||||
|
||||
raw_io = _AsyncPyodideRawIO(None)
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = b""
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def _do_send_sse(self, request: PreparedRequest, url: str, scheme: str) -> AsyncResponse:
|
||||
"""Handle SSE connections via pyfetch streaming + manual parsing."""
|
||||
from ._sse import AsyncPyodideSSEExtension
|
||||
|
||||
http_url = url.replace("sse://", "https://", 1) if scheme == "sse" else url.replace("psse://", "http://", 1)
|
||||
|
||||
# Pass through user-provided headers
|
||||
headers_dict: dict[str, str] = {}
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
if key.lower() not in ("host", "content-length", "connection"):
|
||||
headers_dict[key] = value
|
||||
|
||||
ext = AsyncPyodideSSEExtension()
|
||||
|
||||
try:
|
||||
await ext.start(http_url, headers=headers_dict)
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"SSE connection to {url} failed: {e}")
|
||||
|
||||
response = Response()
|
||||
response.status_code = 200
|
||||
response.headers = CaseInsensitiveDict({"content-type": "text/event-stream"})
|
||||
response.request = request
|
||||
response.url = url
|
||||
response.reason = "OK"
|
||||
|
||||
raw_io = _AsyncPyodideRawIO(None)
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False # type: ignore[assignment]
|
||||
response._content_consumed = False
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ("AsyncPyodideAdapter",)
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class AsyncPyodideSSEExtension:
|
||||
"""Async SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.
|
||||
|
||||
Reads from a ReadableStream reader, buffers partial lines,
|
||||
and parses complete SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._reader: typing.Any = None
|
||||
|
||||
async def start(self, url: str, headers: dict[str, str] | None = None) -> None:
|
||||
"""Open the SSE stream via pyfetch."""
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": "GET",
|
||||
"headers": {
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
**(headers or {}),
|
||||
},
|
||||
}
|
||||
|
||||
js_response = await pyfetch(url, **fetch_options)
|
||||
|
||||
body = js_response.js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the ReadableStream."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = await self._reader.read()
|
||||
if result.done:
|
||||
return None
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py()).decode("utf-8")
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
# Keep reading chunks until we have a complete event (double newline)
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = await self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._reader is not None:
|
||||
try:
|
||||
await self._reader.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from pyodide.ffi import create_proxy # type: ignore[import]
|
||||
|
||||
try:
|
||||
from js import WebSocket as JSWebSocket # type: ignore[import]
|
||||
except ImportError:
|
||||
JSWebSocket = None
|
||||
|
||||
|
||||
class AsyncPyodideWebSocketExtension:
|
||||
"""Async WebSocket extension for Pyodide using the browser's native WebSocket API.
|
||||
|
||||
Messages are queued from JS callbacks and dequeued by next_payload()
|
||||
which awaits until a message arrives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._queue: asyncio.Queue[str | bytes | None] = asyncio.Queue()
|
||||
self._proxies: list[typing.Any] = []
|
||||
self._ws: typing.Any = None
|
||||
|
||||
async def start(self, url: str) -> None:
|
||||
"""Open the WebSocket connection and wait until it's ready."""
|
||||
if JSWebSocket is None: # Defensive: depends on JS runtime
|
||||
raise OSError(
|
||||
"WebSocket is not available in this JavaScript runtime. "
|
||||
"Browser environment required (not supported in Node.js)."
|
||||
)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
open_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
self._ws = JSWebSocket.new(url)
|
||||
self._ws.binaryType = "arraybuffer"
|
||||
|
||||
# JS→Python callbacks: invoked by the browser event loop, not Python call
|
||||
# frames, so coverage cannot trace into them.
|
||||
def _onopen(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if not open_future.done():
|
||||
open_future.set_result(None)
|
||||
|
||||
def _onerror(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if not open_future.done():
|
||||
open_future.set_exception(ConnectionError("WebSocket connection failed"))
|
||||
|
||||
def _onmessage(event: typing.Any) -> None: # Defensive: JS callback
|
||||
data = event.data
|
||||
if isinstance(data, str):
|
||||
self._queue.put_nowait(data)
|
||||
else:
|
||||
# ArrayBuffer → bytes via Pyodide
|
||||
self._queue.put_nowait(bytes(data.to_py()))
|
||||
|
||||
def _onclose(event: typing.Any) -> None: # Defensive: JS callback
|
||||
self._queue.put_nowait(None)
|
||||
|
||||
for name, fn in [
|
||||
("onopen", _onopen),
|
||||
("onerror", _onerror),
|
||||
("onmessage", _onmessage),
|
||||
("onclose", _onclose),
|
||||
]:
|
||||
proxy = create_proxy(fn)
|
||||
self._proxies.append(proxy)
|
||||
setattr(self._ws, name, proxy)
|
||||
|
||||
await open_future
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Await the next message from the WebSocket.
|
||||
Returns None when the remote end closes the connection."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
msg = await self._queue.get()
|
||||
|
||||
if msg is None: # Defensive: server-initiated close sentinel
|
||||
self._closed = True
|
||||
|
||||
return msg
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message over the WebSocket."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
from js import Uint8Array # type: ignore[import]
|
||||
|
||||
self._ws.send(Uint8Array.new(buf))
|
||||
else:
|
||||
self._ws.send(buf)
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""No-op — browser WebSocket handles ping/pong at protocol level."""
|
||||
pass
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket and clean up proxies."""
|
||||
if self._closed: # Defensive: idempotent close
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
|
||||
for proxy in self._proxies:
|
||||
try:
|
||||
proxy.destroy()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
self._proxies.clear()
|
||||
@@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from pyodide.ffi import run_sync # type: ignore[import]
|
||||
from pyodide.http import pyfetch # type: ignore[import]
|
||||
|
||||
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class PyodideSSEExtension:
|
||||
"""SSE extension for Pyodide using pyfetch streaming + manual SSE parsing.
|
||||
|
||||
Synchronous via JSPI (run_sync). Reads from a ReadableStream reader,
|
||||
buffers partial lines, and parses complete SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, headers: dict[str, str] | None = None) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._reader: typing.Any = None
|
||||
|
||||
fetch_options: dict[str, typing.Any] = {
|
||||
"method": "GET",
|
||||
"headers": {
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-store",
|
||||
**(headers or {}),
|
||||
},
|
||||
}
|
||||
|
||||
js_response = run_sync(pyfetch(url, **fetch_options))
|
||||
|
||||
body = js_response.js_response.body
|
||||
if body is not None:
|
||||
self._reader = body.getReader()
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the ReadableStream, blocking via JSPI."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = run_sync(self._reader.read())
|
||||
if result.done:
|
||||
return None
|
||||
value = result.value
|
||||
if value is not None:
|
||||
return bytes(value.to_py()).decode("utf-8")
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
# Keep reading chunks until we have a complete event (double newline)
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._reader is not None:
|
||||
try:
|
||||
run_sync(self._reader.cancel())
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
@@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from js import Promise # type: ignore[import]
|
||||
from pyodide.ffi import create_proxy, run_sync # type: ignore[import]
|
||||
|
||||
try:
|
||||
from js import WebSocket as JSWebSocket # type: ignore[import]
|
||||
except ImportError:
|
||||
JSWebSocket = None
|
||||
|
||||
|
||||
class PyodideWebSocketExtension:
|
||||
"""WebSocket extension for Pyodide using the browser's native WebSocket API.
|
||||
|
||||
Synchronous via JSPI (run_sync). Uses JS Promises for signaling instead of
|
||||
asyncio primitives, because run_sync bypasses the asyncio event loop.
|
||||
Messages from JS callbacks are buffered in a list and delivered via
|
||||
next_payload() which blocks via run_sync on a JS Promise.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str) -> None:
|
||||
if JSWebSocket is None: # Defensive: depends on JS runtime
|
||||
raise OSError(
|
||||
"WebSocket is not available in this JavaScript runtime. "
|
||||
"Browser environment required (not supported in Node.js)."
|
||||
)
|
||||
|
||||
self._closed = False
|
||||
self._pending: list[str | bytes | None] = []
|
||||
self._waiting_resolve: typing.Any = None
|
||||
self._last_msg: str | bytes | None = None
|
||||
self._proxies: list[typing.Any] = []
|
||||
|
||||
# Create browser WebSocket
|
||||
self._ws = JSWebSocket.new(url)
|
||||
self._ws.binaryType = "arraybuffer"
|
||||
|
||||
# Open/error signaling via a JS Promise (not asyncio.Future,
|
||||
# because run_sync/JSPI does not drive the asyncio event loop).
|
||||
_open_state: dict[str, typing.Any] = {
|
||||
"resolve": None,
|
||||
"reject": None,
|
||||
}
|
||||
|
||||
def _open_executor(resolve: typing.Any, reject: typing.Any) -> None:
|
||||
_open_state["resolve"] = resolve
|
||||
_open_state["reject"] = reject
|
||||
|
||||
exec_proxy = create_proxy(_open_executor)
|
||||
open_promise = Promise.new(exec_proxy)
|
||||
self._proxies.append(exec_proxy)
|
||||
|
||||
# JS→Python callbacks: invoked by the browser event loop, not Python call
|
||||
# frames, so coverage cannot trace into them.
|
||||
def _onopen(event: typing.Any) -> None: # Defensive: JS callback
|
||||
r = _open_state["resolve"]
|
||||
if r is not None:
|
||||
_open_state["resolve"] = None
|
||||
_open_state["reject"] = None
|
||||
r()
|
||||
|
||||
def _onerror(event: typing.Any) -> None: # Defensive: JS callback
|
||||
r = _open_state["reject"]
|
||||
if r is not None:
|
||||
_open_state["resolve"] = None
|
||||
_open_state["reject"] = None
|
||||
r("WebSocket connection failed")
|
||||
|
||||
def _onmessage(event: typing.Any) -> None: # Defensive: JS callback
|
||||
data = event.data
|
||||
if isinstance(data, str):
|
||||
msg: str | bytes = data
|
||||
else:
|
||||
# ArrayBuffer → bytes via Pyodide
|
||||
msg = bytes(data.to_py())
|
||||
|
||||
if self._waiting_resolve is not None:
|
||||
self._last_msg = msg
|
||||
r = self._waiting_resolve
|
||||
self._waiting_resolve = None
|
||||
r()
|
||||
else:
|
||||
self._pending.append(msg)
|
||||
|
||||
def _onclose(event: typing.Any) -> None: # Defensive: JS callback
|
||||
if self._waiting_resolve is not None:
|
||||
self._last_msg = None
|
||||
r = self._waiting_resolve
|
||||
self._waiting_resolve = None
|
||||
r()
|
||||
else:
|
||||
self._pending.append(None)
|
||||
|
||||
# Create proxies so JS can call these Python functions
|
||||
for name, fn in [
|
||||
("onopen", _onopen),
|
||||
("onerror", _onerror),
|
||||
("onmessage", _onmessage),
|
||||
("onclose", _onclose),
|
||||
]:
|
||||
proxy = create_proxy(fn)
|
||||
self._proxies.append(proxy)
|
||||
setattr(self._ws, name, proxy)
|
||||
|
||||
# Block until the connection is open (or error rejects the promise)
|
||||
run_sync(open_promise)
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Block (via JSPI) until the next message arrives.
|
||||
Returns None when the remote end closes the connection."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
# Drain from buffer first
|
||||
if self._pending:
|
||||
msg = self._pending.pop(0)
|
||||
if msg is None: # Defensive: serverinitiated close sentinel
|
||||
self._closed = True
|
||||
return msg
|
||||
|
||||
# Wait for the next message via a JS Promise
|
||||
def _executor(resolve: typing.Any, reject: typing.Any) -> None:
|
||||
self._waiting_resolve = resolve
|
||||
|
||||
exec_proxy = create_proxy(_executor)
|
||||
promise = Promise.new(exec_proxy)
|
||||
run_sync(promise)
|
||||
exec_proxy.destroy()
|
||||
|
||||
msg = self._last_msg
|
||||
if msg is None: # Defensive: server-initiated close sentinel
|
||||
self._closed = True
|
||||
return msg
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message over the WebSocket."""
|
||||
if self._closed: # Defensive: caller should not call after close
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
from js import Uint8Array # type: ignore[import]
|
||||
|
||||
self._ws.send(Uint8Array.new(buf))
|
||||
else:
|
||||
self._ws.send(buf)
|
||||
|
||||
def ping(self) -> None:
|
||||
"""No-op — browser WebSocket handles ping/pong at protocol level."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the WebSocket and clean up proxies."""
|
||||
if self._closed: # Defensive: idempotent close
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
|
||||
for proxy in self._proxies:
|
||||
try:
|
||||
proxy.destroy()
|
||||
except Exception: # Defensive: suppress JS errors on teardown
|
||||
pass
|
||||
self._proxies.clear()
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RevocationStrategy(Enum):
|
||||
PREFER_OCSP = 0
|
||||
PREFER_CRL = 1
|
||||
CHECK_ALL = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationConfiguration:
|
||||
strategy: RevocationStrategy | None = RevocationStrategy.PREFER_OCSP
|
||||
strict_mode: bool = False
|
||||
|
||||
|
||||
DEFAULT_STRATEGY: RevocationConfiguration = RevocationConfiguration()
|
||||
@@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
CertificateRevocationList,
|
||||
)
|
||||
|
||||
from ....exceptions import RequestException, SSLError
|
||||
from ....models import PreparedRequest
|
||||
from ....packages.urllib3 import ConnectionInfo
|
||||
from ....packages.urllib3.contrib.resolver import BaseResolver
|
||||
from ....packages.urllib3.exceptions import SecurityWarning
|
||||
from ....typing import ProxyType
|
||||
from .._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
|
||||
|
||||
|
||||
class InMemoryRevocationList:
|
||||
def __init__(self, max_size: int = 256):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, CertificateRevocationList] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._crl_endpoints: dict[str, str] = {}
|
||||
self._failure_count: int = 0
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks: dict[str, threading.RLock] = {}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
with self._access_lock:
|
||||
if fingerprint not in self._second_level_locks:
|
||||
self._second_level_locks[fingerprint] = threading.RLock()
|
||||
|
||||
lock = self._second_level_locks[fingerprint]
|
||||
|
||||
lock.acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
with self._access_lock:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
"_crl_endpoints": self._crl_endpoints,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks = {}
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
self._crl_endpoints = state["_crl_endpoints"]
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = CertificateRevocationList.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._access_lock:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
with self._access_lock:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
|
||||
with self._access_lock:
|
||||
if crl_distribution_point not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[crl_distribution_point]
|
||||
|
||||
if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
|
||||
del self._store[crl_distribution_point]
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
|
||||
fingerprint = _str_fingerprint_of(leaf)
|
||||
|
||||
if fingerprint in self._crl_endpoints:
|
||||
return self._crl_endpoints[fingerprint]
|
||||
|
||||
return None
|
||||
|
||||
def save(
|
||||
self,
|
||||
leaf: Certificate,
|
||||
issuer: Certificate,
|
||||
crl: CertificateRevocationList,
|
||||
crl_distribution_point: str,
|
||||
) -> None:
|
||||
with self._access_lock:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update_at > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(leaf)
|
||||
|
||||
self._store[crl_distribution_point] = crl
|
||||
self._crl_endpoints[peer_fingerprint] = crl_distribution_point
|
||||
self._issuers_map[peer_fingerprint] = issuer
|
||||
self._failure_count = 0
|
||||
|
||||
|
||||
def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: BaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationList | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# no CRL distribution point available.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationList()
|
||||
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
|
||||
|
||||
with cache.lock_for(peer_certificate):
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from ....sessions import Session
|
||||
|
||||
with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = session.get(hint_ca_issuers[0])
|
||||
except RequestException:
|
||||
pass
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
crl_http_response = session.get(
|
||||
crl_distribution_point,
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
|
||||
if crl_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
crl = CertificateRevocationList(crl_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not crl.authenticate_for(issuer_certificate.public_bytes()):
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the CRL response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: The X509 CRL is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
|
||||
|
||||
revocation_status = crl.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"CRL endpoint: {str(crl_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,405 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
CertificateRevocationList,
|
||||
)
|
||||
|
||||
from .....exceptions import RequestException, SSLError
|
||||
from .....models import PreparedRequest
|
||||
from .....packages.urllib3 import ConnectionInfo
|
||||
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
|
||||
from .....packages.urllib3.exceptions import SecurityWarning
|
||||
from .....typing import ProxyType
|
||||
from .....utils import is_cancelled_error_root_cause
|
||||
from ..._ocsp import _parse_x509_der_cached, _str_fingerprint_of, readable_revocation_reason
|
||||
|
||||
|
||||
class InMemoryRevocationList:
|
||||
def __init__(self, max_size: int = 256):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, CertificateRevocationList] = {}
|
||||
self._semaphores: dict[str, asyncio.Semaphore] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._crl_endpoints: dict[str, str] = {}
|
||||
self._failure_count: int = 0
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
"_crl_endpoints": self._crl_endpoints,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state or "_crl_endpoints" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
self._crl_endpoints = state["_crl_endpoints"]
|
||||
self._semaphores = {}
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = CertificateRevocationList.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._semaphores:
|
||||
self._semaphores[fingerprint] = asyncio.Semaphore()
|
||||
|
||||
await self._semaphores[fingerprint].acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._semaphores[fingerprint].release()
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def check(self, crl_distribution_point: str) -> CertificateRevocationList | None:
|
||||
if crl_distribution_point not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[crl_distribution_point]
|
||||
|
||||
if cached_response.next_update_at and datetime.datetime.now().timestamp() >= cached_response.next_update_at:
|
||||
del self._store[crl_distribution_point]
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def get_previous_crl_endpoint(self, leaf: Certificate) -> str | None:
|
||||
fingerprint = _str_fingerprint_of(leaf)
|
||||
|
||||
if fingerprint in self._crl_endpoints:
|
||||
return self._crl_endpoints[fingerprint]
|
||||
|
||||
return None
|
||||
|
||||
def save(
|
||||
self,
|
||||
leaf: Certificate,
|
||||
issuer: Certificate,
|
||||
crl: CertificateRevocationList,
|
||||
crl_distribution_point: str,
|
||||
) -> None:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update_at > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update_at
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(leaf)
|
||||
|
||||
self._store[crl_distribution_point] = crl
|
||||
self._crl_endpoints[peer_fingerprint] = crl_distribution_point
|
||||
self._issuers_map[peer_fingerprint] = issuer
|
||||
self._failure_count = 0
|
||||
|
||||
|
||||
async def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: AsyncBaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationList | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("crlDistributionPoints", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationList()
|
||||
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate: Certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
crl_distribution_point: str = cache.get_previous_crl_endpoint(peer_certificate) or endpoints[randint(0, len(endpoints) - 1)]
|
||||
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate is not None:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
async with cache.lock(peer_certificate):
|
||||
# why are we doing this twice?
|
||||
# because using Semaphore to prevent concurrent
|
||||
# revocation check have a heavy toll!
|
||||
# todo: respect DRY
|
||||
cached_revocation_list = cache.check(crl_distribution_point)
|
||||
|
||||
if cached_revocation_list is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate is not None:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
revocation_status = cached_revocation_list.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from .....async_session import AsyncSession
|
||||
|
||||
async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = await session.get(hint_ca_issuers[0])
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via CRL. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
crl_http_response = await session.get(
|
||||
crl_distribution_point,
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
|
||||
if crl_http_response.status_code and 300 > crl_http_response.status_code >= 200:
|
||||
if crl_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
crl = CertificateRevocationList(crl_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not crl.authenticate_for(issuer_certificate.public_bytes()):
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the CRL response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid "
|
||||
"certificate via CRL. You are seeing this warning due to enabling strict "
|
||||
"mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 CRL is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, crl, crl_distribution_point)
|
||||
|
||||
revocation_status = crl.is_revoked(peer_certificate.serial_number)
|
||||
|
||||
if revocation_status is not None:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(revocation_status.reason)}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via CRL. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"CRL endpoint: {str(crl_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from hashlib import sha256
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
OCSPCertStatus,
|
||||
OCSPRequest,
|
||||
OCSPResponse,
|
||||
OCSPResponseStatus,
|
||||
ReasonFlags,
|
||||
)
|
||||
|
||||
from ....exceptions import RequestException, SSLError
|
||||
from ....models import PreparedRequest
|
||||
from ....packages.urllib3 import ConnectionInfo
|
||||
from ....packages.urllib3.contrib.resolver import BaseResolver
|
||||
from ....packages.urllib3.exceptions import SecurityWarning
|
||||
from ....typing import ProxyType
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _parse_x509_der_cached(der: bytes) -> Certificate:
|
||||
return Certificate(der)
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _fingerprint_raw_data(payload: bytes) -> str:
|
||||
return "".join([format(i, "02x") for i in sha256(payload).digest()])
|
||||
|
||||
|
||||
def _str_fingerprint_of(certificate: Certificate) -> str:
|
||||
return _fingerprint_raw_data(certificate.public_bytes())
|
||||
|
||||
|
||||
def readable_revocation_reason(flag: ReasonFlags | None) -> str | None:
|
||||
return str(flag).split(".")[-1].lower() if flag is not None else None
|
||||
|
||||
|
||||
class InMemoryRevocationStatus:
|
||||
def __init__(self, max_size: int = 2048):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, OCSPResponse] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._timings: list[datetime.datetime] = []
|
||||
self._failure_count: int = 0
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks: dict[str, threading.RLock] = {}
|
||||
self.hold: bool = False
|
||||
|
||||
@staticmethod
|
||||
def support_pickle() -> bool:
|
||||
"""This gives you a hint on whether you can cache it to restore later."""
|
||||
return hasattr(OCSPResponse, "serialize")
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
with self._access_lock:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock_for(self, peer_certificate: Certificate) -> typing.Generator[None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
with self._access_lock:
|
||||
if fingerprint not in self._second_level_locks:
|
||||
self._second_level_locks[fingerprint] = threading.RLock()
|
||||
|
||||
lock = self._second_level_locks[fingerprint]
|
||||
|
||||
lock.acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
lock.release()
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self._access_lock = threading.RLock()
|
||||
self._second_level_locks = {}
|
||||
self.hold = False
|
||||
self._timings = []
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
|
||||
self._store = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = OCSPResponse.deserialize(v)
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._access_lock:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
with self._access_lock:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
def rate(self):
|
||||
with self._access_lock:
|
||||
previous_dt: datetime.datetime | None = None
|
||||
delays: list[float] = []
|
||||
|
||||
for dt in self._timings:
|
||||
if previous_dt is None:
|
||||
previous_dt = dt
|
||||
continue
|
||||
delays.append((dt - previous_dt).total_seconds())
|
||||
previous_dt = dt
|
||||
|
||||
return sum(delays) / len(delays) if delays else 0.0
|
||||
|
||||
def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
|
||||
with self._access_lock:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[fingerprint]
|
||||
|
||||
if cached_response.certificate_status == OCSPCertStatus.GOOD:
|
||||
if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
|
||||
del self._store[fingerprint]
|
||||
return None
|
||||
return cached_response
|
||||
|
||||
return cached_response
|
||||
|
||||
def save(
|
||||
self,
|
||||
peer_certificate: Certificate,
|
||||
issuer_certificate: Certificate,
|
||||
ocsp_response: OCSPResponse,
|
||||
) -> None:
|
||||
with self._access_lock:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
|
||||
tbd_key = k
|
||||
break
|
||||
|
||||
if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
del self._issuers_map[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
del self._issuers_map[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
self._store[peer_fingerprint] = ocsp_response
|
||||
self._issuers_map[peer_fingerprint] = issuer_certificate
|
||||
self._failure_count = 0
|
||||
|
||||
self._timings.append(datetime.datetime.now())
|
||||
|
||||
if len(self._timings) >= self._max_size:
|
||||
self._timings.pop(0)
|
||||
|
||||
|
||||
def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: BaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationStatus | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("OCSP", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationStatus()
|
||||
|
||||
# this feature, by default, is reserved for a reasonable usage.
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
mean_rate_sec = cache.rate()
|
||||
cache_count = len(cache)
|
||||
|
||||
if cache_count >= 10 and mean_rate_sec <= 1.0:
|
||||
cache.hold = True
|
||||
|
||||
if cache.hold:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
with cache.lock_for(peer_certificate):
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from ....sessions import Session
|
||||
|
||||
with Session(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default.
|
||||
# Unfortunately! But no worries, we can circumvent it!
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# try to do AIA fetching of intermediate certificate (issuer)
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = session.get(hint_ca_issuers[0])
|
||||
except RequestException:
|
||||
pass
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
# aia fetching should be counted as general ocsp failure too.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP generator failed to assemble the request."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_http_response = session.post(
|
||||
endpoints[randint(0, len(endpoints) - 1)],
|
||||
data=req.public_bytes(),
|
||||
headers={"Content-Type": "application/ocsp-request"},
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
# we want to monitor failures related to the responder.
|
||||
# we don't want to ruin the http experience in normal circumstances.
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
|
||||
if ocsp_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_resp = OCSPResponse(ocsp_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
if verify_signature:
|
||||
try:
|
||||
if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()): # type: ignore[attr-defined]
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the OCSP response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError: # Defensive: unsupported signature case
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: The X509 OCSP response is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, ocsp_resp)
|
||||
|
||||
if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
|
||||
"certificate is valid or not. This error occurred because you enabled strict mode for "
|
||||
"the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {ocsp_resp.response_status}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {str(ocsp_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,492 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import ipaddress
|
||||
import ssl
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from random import randint
|
||||
|
||||
from qh3._hazmat import (
|
||||
Certificate,
|
||||
OCSPCertStatus,
|
||||
OCSPRequest,
|
||||
OCSPResponse,
|
||||
OCSPResponseStatus,
|
||||
)
|
||||
|
||||
from .....exceptions import RequestException, SSLError
|
||||
from .....models import PreparedRequest
|
||||
from .....packages.urllib3 import ConnectionInfo
|
||||
from .....packages.urllib3.contrib.resolver._async import AsyncBaseResolver
|
||||
from .....packages.urllib3.exceptions import SecurityWarning
|
||||
from .....typing import ProxyType
|
||||
from .....utils import is_cancelled_error_root_cause
|
||||
from .. import (
|
||||
_parse_x509_der_cached,
|
||||
_str_fingerprint_of,
|
||||
readable_revocation_reason,
|
||||
)
|
||||
|
||||
|
||||
class InMemoryRevocationStatus:
|
||||
def __init__(self, max_size: int = 2048):
|
||||
self._max_size: int = max_size
|
||||
self._store: dict[str, OCSPResponse] = {}
|
||||
self._semaphores: dict[str, asyncio.Semaphore] = {}
|
||||
self._issuers_map: dict[str, Certificate] = {}
|
||||
self._timings: list[datetime.datetime] = []
|
||||
self._failure_count: int = 0
|
||||
self.hold: bool = False
|
||||
|
||||
@staticmethod
|
||||
def support_pickle() -> bool:
|
||||
"""This gives you a hint on whether you can cache it to restore later."""
|
||||
return hasattr(OCSPResponse, "serialize")
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {
|
||||
"_max_size": self._max_size,
|
||||
"_store": {k: v.serialize() for k, v in self._store.items()},
|
||||
"_issuers_map": {k: v.serialize() for k, v in self._issuers_map.items()},
|
||||
"_failure_count": self._failure_count,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
if "_store" not in state or "_issuers_map" not in state or "_max_size" not in state:
|
||||
raise OSError("unrecoverable state for InMemoryRevocationStatus")
|
||||
|
||||
self.hold = False
|
||||
self._timings = []
|
||||
|
||||
self._max_size = state["_max_size"]
|
||||
self._failure_count = state["_failure_count"] if "_failure_count" in state else 0
|
||||
|
||||
self._store = {}
|
||||
self._semaphores = {}
|
||||
|
||||
for k, v in state["_store"].items():
|
||||
self._store[k] = OCSPResponse.deserialize(v)
|
||||
self._semaphores[k] = asyncio.Semaphore()
|
||||
|
||||
self._issuers_map = {}
|
||||
|
||||
for k, v in state["_issuers_map"].items():
|
||||
self._issuers_map[k] = Certificate.deserialize(v)
|
||||
|
||||
def get_issuer_of(self, peer_certificate: Certificate) -> Certificate | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._issuers_map:
|
||||
return None
|
||||
|
||||
return self._issuers_map[fingerprint]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def incr_failure(self) -> None:
|
||||
self._failure_count += 1
|
||||
|
||||
@property
|
||||
def failure_count(self) -> int:
|
||||
return self._failure_count
|
||||
|
||||
@asynccontextmanager
|
||||
async def lock(self, peer_certificate: Certificate) -> typing.AsyncGenerator[None, None]:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._semaphores:
|
||||
self._semaphores[fingerprint] = asyncio.Semaphore()
|
||||
|
||||
await self._semaphores[fingerprint].acquire()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._semaphores[fingerprint].release()
|
||||
|
||||
def rate(self):
|
||||
previous_dt: datetime.datetime | None = None
|
||||
delays: list[float] = []
|
||||
|
||||
for dt in self._timings:
|
||||
if previous_dt is None:
|
||||
previous_dt = dt
|
||||
continue
|
||||
delays.append((dt - previous_dt).total_seconds())
|
||||
previous_dt = dt
|
||||
|
||||
return sum(delays) / len(delays) if delays else 0.0
|
||||
|
||||
def check(self, peer_certificate: Certificate) -> OCSPResponse | None:
|
||||
fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
if fingerprint not in self._store:
|
||||
return None
|
||||
|
||||
cached_response = self._store[fingerprint]
|
||||
|
||||
if cached_response.certificate_status == OCSPCertStatus.GOOD:
|
||||
if cached_response.next_update and datetime.datetime.now().timestamp() >= cached_response.next_update:
|
||||
del self._store[fingerprint]
|
||||
return None
|
||||
return cached_response
|
||||
|
||||
return cached_response
|
||||
|
||||
def save(
|
||||
self,
|
||||
peer_certificate: Certificate,
|
||||
issuer_certificate: Certificate,
|
||||
ocsp_response: OCSPResponse,
|
||||
) -> None:
|
||||
if len(self._store) >= self._max_size:
|
||||
tbd_key: str | None = None
|
||||
closest_next_update: int | None = None
|
||||
|
||||
for k in self._store:
|
||||
if self._store[k].response_status != OCSPResponseStatus.SUCCESSFUL:
|
||||
tbd_key = k
|
||||
break
|
||||
|
||||
if self._store[k].certificate_status != OCSPCertStatus.REVOKED:
|
||||
if closest_next_update is None:
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
continue
|
||||
if self._store[k].next_update > closest_next_update: # type: ignore
|
||||
closest_next_update = self._store[k].next_update
|
||||
tbd_key = k
|
||||
|
||||
if tbd_key:
|
||||
del self._store[tbd_key]
|
||||
del self._issuers_map[tbd_key]
|
||||
else:
|
||||
first_key = list(self._store.keys())[0]
|
||||
del self._store[first_key]
|
||||
del self._issuers_map[first_key]
|
||||
|
||||
peer_fingerprint: str = _str_fingerprint_of(peer_certificate)
|
||||
|
||||
self._store[peer_fingerprint] = ocsp_response
|
||||
self._issuers_map[peer_fingerprint] = issuer_certificate
|
||||
self._failure_count = 0
|
||||
|
||||
self._timings.append(datetime.datetime.now())
|
||||
|
||||
if len(self._timings) >= self._max_size:
|
||||
self._timings.pop(0)
|
||||
|
||||
|
||||
async def verify(
|
||||
r: PreparedRequest,
|
||||
strict: bool = False,
|
||||
timeout: float | int = 0.2,
|
||||
proxies: ProxyType | None = None,
|
||||
resolver: AsyncBaseResolver | None = None,
|
||||
happy_eyeballs: bool | int = False,
|
||||
cache: InMemoryRevocationStatus | None = None,
|
||||
) -> None:
|
||||
conn_info: ConnectionInfo | None = r.conn_info
|
||||
|
||||
# we can't do anything in that case.
|
||||
if conn_info is None or conn_info.certificate_der is None or conn_info.certificate_dict is None:
|
||||
return
|
||||
|
||||
endpoints: list[str] = [ # type: ignore
|
||||
# exclude non-HTTP endpoint. like ldap.
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("OCSP", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
# well... not all issued certificate have a OCSP entry. e.g. mkcert.
|
||||
if not endpoints:
|
||||
return
|
||||
|
||||
if cache is None:
|
||||
cache = InMemoryRevocationStatus()
|
||||
|
||||
peer_certificate = _parse_x509_der_cached(conn_info.certificate_der)
|
||||
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
async with cache.lock(peer_certificate):
|
||||
# this feature, by default, is reserved for a reasonable usage.
|
||||
if not strict:
|
||||
if cache.failure_count >= 4:
|
||||
return
|
||||
|
||||
mean_rate_sec = cache.rate()
|
||||
cache_count = len(cache)
|
||||
|
||||
if cache_count >= 10 and mean_rate_sec <= 1.0:
|
||||
cache.hold = True
|
||||
|
||||
if cache.hold:
|
||||
return
|
||||
|
||||
# some corporate environment
|
||||
# have invalid OCSP implementation
|
||||
# they use a cert that IS NOT in the chain
|
||||
# to sign the response. It's weird but true.
|
||||
# see https://github.com/jawah/niquests/issues/274
|
||||
ignore_signature_without_strict = ipaddress.ip_address(conn_info.destination_address[0]).is_private or bool(proxies)
|
||||
verify_signature = strict is True or ignore_signature_without_strict is False
|
||||
|
||||
# why are we doing this twice?
|
||||
# because using Semaphore to prevent concurrent
|
||||
# revocation check have a heavy toll!
|
||||
# todo: respect DRY
|
||||
cached_response = cache.check(peer_certificate)
|
||||
|
||||
if cached_response is not None:
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
if issuer_certificate:
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
|
||||
if cached_response.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if cached_response.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(cached_response.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. ",
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information.",
|
||||
)
|
||||
)
|
||||
elif cached_response.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know "
|
||||
"whether certificate is valid or not. This error occurred because you enabled strict mode "
|
||||
"for the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
|
||||
return
|
||||
|
||||
from .....async_session import AsyncSession
|
||||
|
||||
async with AsyncSession(resolver=resolver, happy_eyeballs=happy_eyeballs) as session:
|
||||
session.trust_env = False
|
||||
if proxies:
|
||||
session.proxies = proxies
|
||||
|
||||
# When using Python native capabilities, you won't have the issuerCA DER by default (Python 3.7 to 3.9).
|
||||
# Unfortunately! But no worries, we can circumvent it! (Python 3.10+ is not concerned anymore)
|
||||
# Three ways are valid to fetch it (in order of preference, safest to riskiest):
|
||||
# - The issuer can be (but unlikely) a root CA.
|
||||
# - Retrieve it by asking it from the TLS layer.
|
||||
# - Downloading it using specified caIssuers from the peer certificate.
|
||||
if conn_info.issuer_certificate_der is None:
|
||||
# It could be a root (self-signed) certificate. Or a previously seen issuer.
|
||||
issuer_certificate = cache.get_issuer_of(peer_certificate)
|
||||
|
||||
hint_ca_issuers: list[str] = [
|
||||
ep # type: ignore
|
||||
for ep in list(conn_info.certificate_dict.get("caIssuers", [])) # type: ignore
|
||||
if ep.startswith("http://") # type: ignore
|
||||
]
|
||||
|
||||
if issuer_certificate is None and hint_ca_issuers:
|
||||
try:
|
||||
raw_intermediary_response = await session.get(hint_ca_issuers[0])
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
else:
|
||||
if raw_intermediary_response.status_code and 300 > raw_intermediary_response.status_code >= 200:
|
||||
raw_intermediary_content = raw_intermediary_response.content
|
||||
|
||||
if raw_intermediary_content is not None:
|
||||
# binary DER
|
||||
if b"-----BEGIN CERTIFICATE-----" not in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(raw_intermediary_content)
|
||||
# b64 PEM
|
||||
elif b"-----BEGIN CERTIFICATE-----" in raw_intermediary_content:
|
||||
issuer_certificate = Certificate(
|
||||
ssl.PEM_cert_to_DER_cert(raw_intermediary_content.decode())
|
||||
)
|
||||
|
||||
# Well! We're out of luck. No further should we go.
|
||||
if issuer_certificate is None:
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate "
|
||||
"via OCSP. You are seeing this warning due to enabling strict mode for OCSP / "
|
||||
"Revocation check. Reason: Remote did not provide any intermediary certificate."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
conn_info.issuer_certificate_der = issuer_certificate.public_bytes()
|
||||
else:
|
||||
issuer_certificate = Certificate(conn_info.issuer_certificate_der)
|
||||
|
||||
try:
|
||||
req = OCSPRequest(peer_certificate.public_bytes(), issuer_certificate.public_bytes())
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP generator failed to assemble the request."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_http_response = await session.post(
|
||||
endpoints[randint(0, len(endpoints) - 1)],
|
||||
data=req.public_bytes(),
|
||||
headers={"Content-Type": "application/ocsp-request"},
|
||||
timeout=timeout,
|
||||
)
|
||||
except RequestException as e:
|
||||
if is_cancelled_error_root_cause(e):
|
||||
return
|
||||
cache.incr_failure()
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"Reason: {e}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError: # don't raise any error or warnings!
|
||||
return
|
||||
|
||||
if ocsp_http_response.status_code and 300 > ocsp_http_response.status_code >= 200:
|
||||
if ocsp_http_response.content is None:
|
||||
return
|
||||
|
||||
try:
|
||||
ocsp_resp = OCSPResponse(ocsp_http_response.content)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP parser failed to read the response"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
return
|
||||
|
||||
if verify_signature:
|
||||
# Verify the signature of the OCSP response with issuer public key
|
||||
try:
|
||||
if not ocsp_resp.authenticate_for(issuer_certificate.public_bytes()): # type: ignore[attr-defined]
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} "
|
||||
"because the OCSP response received has been tampered. "
|
||||
"You could be targeted by a MITM attack."
|
||||
)
|
||||
except ValueError:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently "
|
||||
"valid certificate via OCSP. You are seeing this warning due to "
|
||||
"enabling strict mode for OCSP / Revocation check. "
|
||||
"Reason: The X509 OCSP response is signed using an unsupported algorithm."
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
cache.save(peer_certificate, issuer_certificate, ocsp_resp)
|
||||
|
||||
if ocsp_resp.response_status == OCSPResponseStatus.SUCCESSFUL:
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.REVOKED:
|
||||
r.ocsp_verified = False
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the certificate has been revoked "
|
||||
f"by issuer ({readable_revocation_reason(ocsp_resp.revocation_reason) or 'unspecified'}). "
|
||||
"You should avoid trying to request anything from it as the remote has been compromised. "
|
||||
"See https://niquests.readthedocs.io/en/latest/user/advanced.html#ocsp-or-certificate-revocation "
|
||||
"for more information."
|
||||
)
|
||||
if ocsp_resp.certificate_status == OCSPCertStatus.UNKNOWN:
|
||||
r.ocsp_verified = False
|
||||
if strict is True:
|
||||
raise SSLError(
|
||||
f"Unable to establish a secure connection to {r.url} because the issuer does not know whether "
|
||||
"certificate is valid or not. This error occurred because you enabled strict mode for "
|
||||
"the OCSP / Revocation check."
|
||||
)
|
||||
else:
|
||||
r.ocsp_verified = True
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {ocsp_resp.response_status}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
else:
|
||||
if strict:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Unable to insure that the remote peer ({r.url}) has a currently valid certificate via OCSP. "
|
||||
"You are seeing this warning due to enabling strict mode for OCSP / Revocation check. "
|
||||
f"OCSP Server Status: {str(ocsp_http_response)}"
|
||||
),
|
||||
SecurityWarning,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("verify",)
|
||||
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..._constant import DEFAULT_RETRIES
|
||||
from ...adapters import BaseAdapter
|
||||
from ...models import PreparedRequest, Response
|
||||
from ...packages.urllib3.exceptions import MaxRetryError
|
||||
from ...packages.urllib3.response import BytesQueueBuffer
|
||||
from ...packages.urllib3.response import HTTPResponse as BaseHTTPResponse
|
||||
from ...packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ...packages.urllib3.util.retry import Retry
|
||||
from ...structures import CaseInsensitiveDict
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...typing import ProxyType, RetryType, TLSClientCertType, TLSVerifyType, WSGIApp
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class _WSGIRawIO:
|
||||
"""File-like wrapper around a WSGI response iterator for streaming."""
|
||||
|
||||
def __init__(self, generator: typing.Generator[bytes, None, None], headers: list[tuple[str, str]]) -> None:
|
||||
self._generator = generator
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self.headers = headers
|
||||
self.extension: typing.Any = None
|
||||
|
||||
def read(
|
||||
self,
|
||||
amt: int | None = None,
|
||||
decode_content: bool = True,
|
||||
) -> bytes:
|
||||
if self._closed:
|
||||
return b""
|
||||
|
||||
if amt is None or amt < 0:
|
||||
# Read all remaining
|
||||
for chunk in self._generator:
|
||||
self._buffer.put(chunk)
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
# Read specific amount
|
||||
while len(self._buffer) < amt:
|
||||
try:
|
||||
self._buffer.put(next(self._generator))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.Generator[bytes, None, None]:
|
||||
"""Iterate over chunks of the response."""
|
||||
while True:
|
||||
chunk = self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
if hasattr(self._generator, "close"):
|
||||
self._generator.close()
|
||||
|
||||
def __iter__(self) -> typing.Iterator[bytes]:
|
||||
return self
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
chunk = self.read(8192)
|
||||
if not chunk:
|
||||
raise StopIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class WebServerGatewayInterface(BaseAdapter):
|
||||
"""Adapter for making requests to WSGI applications directly."""
|
||||
|
||||
def __init__(self, app: WSGIApp, max_retries: RetryType = DEFAULT_RETRIES) -> None:
|
||||
"""
|
||||
Initialize the WSGI adapter.
|
||||
|
||||
:param app: A WSGI application callable.
|
||||
:param max_retries: Maximum number of retries for requests.
|
||||
"""
|
||||
super().__init__()
|
||||
self.app = app
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<WSGIAdapter />"
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest to the WSGI application."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self._do_send(request, stream)
|
||||
except Exception as err:
|
||||
try:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
except MaxRetryError:
|
||||
raise
|
||||
|
||||
retries.sleep()
|
||||
continue
|
||||
|
||||
# we rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
retries.sleep(base_response)
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest) -> Response:
|
||||
"""Handle SSE requests by wrapping the WSGI response with SSE parsing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._sse import WSGISSEExtension
|
||||
|
||||
# Convert scheme: sse:// -> https://, psse:// -> http://
|
||||
parsed = urlparse(request.url)
|
||||
if parsed.scheme == "sse":
|
||||
http_scheme = "https"
|
||||
else:
|
||||
http_scheme = "http"
|
||||
|
||||
original_url = request.url
|
||||
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
|
||||
|
||||
environ = self._create_environ(request)
|
||||
request.url = original_url
|
||||
|
||||
status_code = None
|
||||
response_headers: list[tuple[str, str]] = []
|
||||
|
||||
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
|
||||
nonlocal status_code, response_headers
|
||||
status_code = int(status.split(" ", 1)[0])
|
||||
response_headers = headers
|
||||
|
||||
result = self.app(environ, start_response)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
yield from result
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
|
||||
ext = WSGISSEExtension(generate())
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = original_url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
raw_io = _WSGIRawIO(iter([]), response_headers) # type: ignore[arg-type]
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
|
||||
return response
|
||||
|
||||
def _do_send(self, request: PreparedRequest, stream: bool) -> Response:
|
||||
"""Perform the actual WSGI request."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
raise NotImplementedError("WebSocket is not supported over WSGI")
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request)
|
||||
|
||||
environ = self._create_environ(request)
|
||||
|
||||
status_code = None
|
||||
response_headers: list[tuple[str, str]] = []
|
||||
|
||||
def start_response(status: str, headers: list[tuple[str, str]], exc_info=None):
|
||||
nonlocal status_code, response_headers
|
||||
status_code = int(status.split(" ", 1)[0])
|
||||
response_headers = headers
|
||||
|
||||
result = self.app(environ, start_response)
|
||||
|
||||
response = Response()
|
||||
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(response_headers)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
# Wrap the WSGI iterator for streaming
|
||||
def generate():
|
||||
try:
|
||||
yield from result
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
|
||||
response.raw = _WSGIRawIO(generate(), response_headers) # type: ignore
|
||||
|
||||
if stream:
|
||||
response._content = False # Indicate content not yet consumed
|
||||
response._content_consumed = False
|
||||
else:
|
||||
# Consume all content immediately
|
||||
body_chunks: list[bytes] = []
|
||||
try:
|
||||
for chunk in result:
|
||||
body_chunks.append(chunk)
|
||||
finally:
|
||||
if hasattr(result, "close"):
|
||||
result.close()
|
||||
response._content = b"".join(body_chunks)
|
||||
|
||||
return response
|
||||
|
||||
def _create_environ(self, request: PreparedRequest) -> dict:
|
||||
"""Create a WSGI environ dict from a PreparedRequest."""
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
body = request.body or b""
|
||||
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
elif isinstance(body, typing.Iterable) and not isinstance(body, (str, bytes, bytearray, tuple)):
|
||||
tmp = b""
|
||||
|
||||
for chunk in body:
|
||||
tmp += chunk # type: ignore[operator]
|
||||
|
||||
body = tmp
|
||||
|
||||
environ = {
|
||||
"REQUEST_METHOD": request.method,
|
||||
"SCRIPT_NAME": "",
|
||||
"PATH_INFO": unquote(parsed.path) or "/",
|
||||
"QUERY_STRING": parsed.query or "",
|
||||
"SERVER_NAME": parsed.hostname or "localhost",
|
||||
"SERVER_PORT": str(parsed.port or (443 if parsed.scheme == "https" else 80)),
|
||||
"SERVER_PROTOCOL": "HTTP/1.1",
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": parsed.scheme or "http",
|
||||
"wsgi.input": BytesIO(body), # type: ignore[arg-type]
|
||||
"wsgi.errors": BytesIO(),
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
"CONTENT_LENGTH": str(len(body)), # type: ignore[arg-type]
|
||||
}
|
||||
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
key_upper = key.upper().replace("-", "_")
|
||||
if key_upper == "CONTENT_TYPE":
|
||||
environ["CONTENT_TYPE"] = value
|
||||
elif key_upper == "CONTENT_LENGTH":
|
||||
environ["CONTENT_LENGTH"] = value
|
||||
else:
|
||||
environ[f"HTTP_{key_upper}"] = value
|
||||
|
||||
return environ
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
pass
|
||||
@@ -0,0 +1,821 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import threading
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
from ...._constant import DEFAULT_RETRIES
|
||||
from ....adapters import AsyncBaseAdapter, BaseAdapter
|
||||
from ....exceptions import ConnectTimeout, ReadTimeout
|
||||
from ....models import AsyncResponse, PreparedRequest, Response
|
||||
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
|
||||
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
|
||||
from ....packages.urllib3.exceptions import MaxRetryError
|
||||
from ....packages.urllib3.response import BytesQueueBuffer
|
||||
from ....packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ....packages.urllib3.util.retry import Retry
|
||||
from ....structures import CaseInsensitiveDict
|
||||
from ....utils import _swap_context
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ....typing import ASGIApp, ASGIMessage, ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _ASGIRawIO:
|
||||
"""Async file-like wrapper around an ASGI response for true async streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
self._response_queue = response_queue
|
||||
self._response_complete = response_complete
|
||||
self._timeout = timeout
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = False
|
||||
self._task: asyncio.Task | None = None
|
||||
self.headers: dict | None = None
|
||||
self.extension: typing.Any = None
|
||||
|
||||
async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
|
||||
if self._closed or self._finished:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = await self._get_next_chunk() # type: ignore[assignment]
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
async def _get_next_chunk(self) -> bytes | None:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
return message.get("body", b"")
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
except asyncio.CancelledError:
|
||||
return None
|
||||
|
||||
async def _async_iter_chunks(self) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
async def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes]:
|
||||
return self._async_stream(amt)
|
||||
|
||||
async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
chunk = await self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
self._response_complete.set()
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
return self._async_iter_self()
|
||||
|
||||
async def _async_iter_self(self) -> typing.AsyncIterator[bytes]:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
yield chunk
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
chunk = await self.read(8192)
|
||||
if not chunk:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class AsyncServerGatewayInterface(AsyncBaseAdapter):
|
||||
"""Adapter for making requests to ASGI applications directly."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
lifespan_state: dict[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self._lifespan_state = lifespan_state
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Native/>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
|
||||
on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> AsyncResponse:
|
||||
"""Send a PreparedRequest to the ASGI application."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
try:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
except MaxRetryError:
|
||||
raise
|
||||
|
||||
await retries.async_sleep()
|
||||
continue
|
||||
|
||||
# we rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
await retries.async_sleep(base_response)
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
async def _do_send_ws(self, request: PreparedRequest) -> AsyncResponse:
|
||||
"""Handle WebSocket requests via the ASGI websocket protocol."""
|
||||
from ._ws import ASGIWebSocketExtension
|
||||
|
||||
scope = self._create_ws_scope(request)
|
||||
|
||||
ext = ASGIWebSocketExtension()
|
||||
await ext.start(self.app, scope)
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket"})
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event())
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send_sse(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Handle SSE requests via ASGI HTTP streaming with SSE parsing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._sse import ASGISSEExtension
|
||||
|
||||
# Convert scheme: sse:// -> https://, psse:// -> http://
|
||||
parsed = urlparse(request.url)
|
||||
if parsed.scheme == "sse":
|
||||
http_scheme = "https"
|
||||
else:
|
||||
http_scheme = "http"
|
||||
|
||||
# Build a modified request with http scheme for scope creation
|
||||
original_url = request.url
|
||||
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
|
||||
|
||||
scope = self._create_scope(request)
|
||||
request.url = original_url
|
||||
|
||||
body = request.body or b""
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
ext = ASGISSEExtension()
|
||||
start_message = await ext.start(self.app, scope, body) # type: ignore[arg-type]
|
||||
|
||||
status_code = start_message["status"]
|
||||
response_headers = start_message.get("headers", [])
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = original_url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event(), timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Perform the actual ASGI request."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return await self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return await self._do_send_sse(request, timeout)
|
||||
|
||||
scope = self._create_scope(request)
|
||||
|
||||
body = request.body or b""
|
||||
body_iter: typing.AsyncIterator[bytes] | typing.AsyncIterator[str] | None = None
|
||||
|
||||
# Check if body is an async iterable
|
||||
if hasattr(body, "__aiter__"):
|
||||
body_iter = body.__aiter__()
|
||||
body = b"" # Will be streamed
|
||||
elif isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
response_queue: asyncio.Queue[ASGIMessage | None] = asyncio.Queue()
|
||||
app_exception: Exception | None = None
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
if body_iter is not None:
|
||||
# Stream chunks from async iterable
|
||||
try:
|
||||
chunk = await body_iter.__anext__()
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopAsyncIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
else:
|
||||
# Single body chunk
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send_func(message: ASGIMessage) -> None:
|
||||
await response_queue.put(message)
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
nonlocal app_exception
|
||||
try:
|
||||
await self.app(scope, receive, send_func)
|
||||
except Exception as ex:
|
||||
app_exception = ex
|
||||
finally:
|
||||
await response_queue.put(None)
|
||||
|
||||
if stream:
|
||||
return await self._stream_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
else:
|
||||
return await self._buffered_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
|
||||
async def _stream_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
# Wait for http.response.start with timeout
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
break
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
raw_io = _ASGIRawIO(response_queue, response_complete, timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io._task = task
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ConnectTimeout("Timed out waiting for ASGI response headers")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
async def _buffered_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
body_chunks: list[bytes] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
elif message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
body_chunks.append(chunk)
|
||||
|
||||
await task
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ReadTimeout("Timed out reading ASGI response body")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
if self.raise_app_exceptions and get_exception() is not None:
|
||||
raise get_exception() # type: ignore
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response._content = b"".join(body_chunks)
|
||||
response.raw = _ASGIRawIO(response_queue, response_complete, timeout) # type: ignore
|
||||
response.raw.headers = headers_dict
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
def _create_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"scheme": "http",
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
# Include lifespan state if available (for frameworks like Starlette that use it)
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
def _create_ws_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scheme = "wss" if parsed.scheme == "wss" else "ws"
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"scheme": scheme,
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if scheme == "wss" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ThreadAsyncServerGatewayInterface(BaseAdapter):
|
||||
"""Synchronous adapter for ASGI applications using a background event loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
self._async_adapter: AsyncServerGatewayInterface | None = None
|
||||
self._loop: typing.Any = None # asyncio.AbstractEventLoop
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
self._lifespan_task: asyncio.Task | None = None
|
||||
self._lifespan_receive_queue: asyncio.Queue[ASGIMessage] | None = None
|
||||
self._lifespan_startup_complete = threading.Event()
|
||||
self._lifespan_startup_failed: Exception | None = None
|
||||
self._lifespan_state: dict[str, typing.Any] = {}
|
||||
self._startup_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Thread/>"
|
||||
|
||||
def _ensure_loop_running(self) -> None:
|
||||
"""Start the background event loop thread if not already running."""
|
||||
with self._startup_lock:
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
|
||||
import asyncio
|
||||
|
||||
def run_loop() -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._lifespan_receive_queue = asyncio.Queue()
|
||||
self._async_adapter = AsyncServerGatewayInterface(
|
||||
self.app,
|
||||
raise_app_exceptions=self.raise_app_exceptions,
|
||||
max_retries=self.max_retries,
|
||||
lifespan_state=self._lifespan_state,
|
||||
)
|
||||
# Start lifespan handler
|
||||
self._lifespan_task = self._loop.create_task(self._handle_lifespan())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait()
|
||||
self._lifespan_startup_complete.wait()
|
||||
|
||||
if self._lifespan_startup_failed is not None:
|
||||
raise self._lifespan_startup_failed
|
||||
|
||||
async def _handle_lifespan(self) -> None:
|
||||
"""Handle ASGI lifespan protocol."""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0"},
|
||||
"state": self._lifespan_state,
|
||||
}
|
||||
|
||||
startup_complete = asyncio.Event()
|
||||
shutdown_complete = asyncio.Event()
|
||||
startup_failed: list[Exception] = []
|
||||
|
||||
# Keep local reference to avoid race condition during shutdown
|
||||
receive_queue = self._lifespan_receive_queue
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
return await receive_queue.get() # type: ignore[union-attr]
|
||||
|
||||
async def send(message: ASGIMessage) -> None:
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.startup.failed":
|
||||
startup_failed.append(RuntimeError(message.get("message", "Lifespan startup failed")))
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
shutdown_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.failed":
|
||||
shutdown_complete.set()
|
||||
|
||||
async def run_lifespan() -> None:
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
if not startup_complete.is_set():
|
||||
startup_failed.append(e)
|
||||
startup_complete.set()
|
||||
|
||||
lifespan_task = asyncio.create_task(run_lifespan())
|
||||
|
||||
# Send startup event
|
||||
await receive_queue.put({"type": "lifespan.startup"}) # type: ignore[union-attr]
|
||||
await startup_complete.wait()
|
||||
|
||||
if startup_failed:
|
||||
self._lifespan_startup_failed = startup_failed[0]
|
||||
self._lifespan_startup_complete.set()
|
||||
|
||||
# Wait for shutdown signal (loop.stop() will cancel this)
|
||||
try:
|
||||
await asyncio.Future() # Wait forever until canceled
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
pass
|
||||
|
||||
# Send shutdown event - must happen before loop stops
|
||||
if receive_queue is not None:
|
||||
try:
|
||||
await receive_queue.put({"type": "lifespan.shutdown"})
|
||||
await asyncio.wait_for(shutdown_complete.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, RuntimeError):
|
||||
pass
|
||||
lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await lifespan_task
|
||||
|
||||
def _do_send_ws(self, request: PreparedRequest) -> Response:
|
||||
"""Handle WebSocket requests synchronously via the background loop."""
|
||||
from .._ws import ThreadASGIWebSocketExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_ws() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_ws(request) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async WS extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGIWebSocketExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_ws()))
|
||||
return future.result()
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest, timeout: int | float | None = None) -> Response:
|
||||
"""Handle SSE requests synchronously via the background loop."""
|
||||
from .._sse import ThreadASGISSEExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_sse() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_sse(request, timeout) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async SSE extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGISSEExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_sse()))
|
||||
return future.result()
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest to the ASGI application synchronously."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request, timeout)
|
||||
|
||||
if stream:
|
||||
raise ValueError(
|
||||
"ThreadAsyncServerGatewayInterface does not support streaming responses. "
|
||||
"Use stream=False or migrate to pure async/await implementation."
|
||||
)
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_send() -> None:
|
||||
try:
|
||||
result = await self._async_adapter.send( # type: ignore[union-attr]
|
||||
request,
|
||||
stream=False,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
)
|
||||
_swap_context(result)
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_send()))
|
||||
|
||||
return future.result()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
if self._loop is not None and self._lifespan_task is not None:
|
||||
# Signal shutdown and wait for it to complete
|
||||
shutdown_done = threading.Event()
|
||||
|
||||
async def do_shutdown() -> None:
|
||||
if self._lifespan_task is not None:
|
||||
self._lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._lifespan_task
|
||||
shutdown_done.set()
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(do_shutdown()))
|
||||
shutdown_done.wait(timeout=6.0)
|
||||
|
||||
if self._loop is not None:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
self._thread = None
|
||||
# Clear resources only after thread has stopped
|
||||
self._loop = None
|
||||
self._async_adapter = None
|
||||
self._lifespan_task = None
|
||||
self._lifespan_receive_queue = None
|
||||
self._started.clear()
|
||||
self._lifespan_startup_complete.clear()
|
||||
self._lifespan_startup_failed = None
|
||||
|
||||
|
||||
__all__ = ("AsyncServerGatewayInterface", "ThreadAsyncServerGatewayInterface")
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class ASGISSEExtension:
|
||||
"""Async SSE extension for ASGI applications.
|
||||
|
||||
Runs the ASGI app as a normal HTTP streaming request and parses
|
||||
SSE events from the response body chunks.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._response_queue: asyncio.Queue[dict[str, typing.Any] | None] | None = None
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(
|
||||
self,
|
||||
app: typing.Any,
|
||||
scope: dict[str, typing.Any],
|
||||
body: bytes = b"",
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Start the ASGI app and wait for http.response.start.
|
||||
|
||||
Returns the response start message (with status and headers).
|
||||
"""
|
||||
self._response_queue = asyncio.Queue()
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._response_queue.put(message) # type: ignore[union-attr]
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
finally:
|
||||
await self._response_queue.put(None) # type: ignore[union-attr]
|
||||
|
||||
self._task = asyncio.create_task(run_app())
|
||||
|
||||
# Wait for http.response.start
|
||||
while True:
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
raise ConnectionError("ASGI app closed before sending response headers")
|
||||
if message["type"] == "http.response.start":
|
||||
return message
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the ASGI response stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data from the ASGI response queue
|
||||
chunk = await self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
async def _read_chunk(self) -> str | None:
|
||||
"""Read the next body chunk from the ASGI response."""
|
||||
if self._response_queue is None:
|
||||
return None
|
||||
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
return body.decode("utf-8")
|
||||
return None
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the SSE stream and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
|
||||
class ASGIWebSocketExtension:
|
||||
"""Async WebSocket extension for ASGI applications.
|
||||
|
||||
Uses the ASGI websocket protocol with send/receive queues to communicate
|
||||
with the application task.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._app_send_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._app_receive_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self, app: typing.Any, scope: dict[str, typing.Any]) -> None:
|
||||
"""Start the ASGI app task and perform the WebSocket handshake."""
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
return await self._app_receive_queue.get()
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._app_send_queue.put(message)
|
||||
|
||||
self._task = asyncio.create_task(app(scope, receive, send))
|
||||
|
||||
# Send connect and wait for accept
|
||||
await self._app_receive_queue.put({"type": "websocket.connect"})
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"WebSocket connection rejected with code {message.get('code', 1000)}")
|
||||
|
||||
if message["type"] != "websocket.accept":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"Unexpected ASGI message during handshake: {message['type']}")
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Await the next message from the ASGI WebSocket app.
|
||||
Returns None when the app closes the connection."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.send":
|
||||
if "text" in message:
|
||||
return message["text"]
|
||||
if "bytes" in message:
|
||||
return message["bytes"]
|
||||
return b""
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message to the ASGI WebSocket app."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "bytes": bytes(buf)})
|
||||
else:
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "text": buf})
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
await self._app_receive_queue.put({"type": "websocket.disconnect", "code": 1000})
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
from ...packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
from ._async._sse import ASGISSEExtension
|
||||
|
||||
|
||||
class WSGISSEExtension:
|
||||
"""SSE extension for WSGI applications.
|
||||
|
||||
Reads from a WSGI response iterator, buffers text, and parses SSE events.
|
||||
"""
|
||||
|
||||
def __init__(self, generator: typing.Generator[bytes, None, None]) -> None:
|
||||
self._generator = generator
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the WSGI response.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data
|
||||
chunk = self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
def _read_chunk(self) -> str | None:
|
||||
"""Read the next chunk from the WSGI response iterator."""
|
||||
try:
|
||||
chunk = next(self._generator)
|
||||
if chunk:
|
||||
return chunk.decode("utf-8")
|
||||
return None
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the stream and release resources."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if hasattr(self._generator, "close"):
|
||||
self._generator.close()
|
||||
|
||||
|
||||
class ThreadASGISSEExtension:
|
||||
"""Synchronous SSE extension wrapping an async ASGISSEExtension.
|
||||
|
||||
Delegates all operations to the async extension on a background event loop,
|
||||
blocking the calling thread via concurrent.futures.Future.
|
||||
"""
|
||||
|
||||
def __init__(self, async_ext: ASGISSEExtension, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._async_ext = async_ext
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._async_ext.closed
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Block until the next SSE event arrives from the ASGI app."""
|
||||
future: Future[ServerSentEvent | str | None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
result = await self._async_ext.next_payload(raw=raw)
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
return future.result()
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the SSE stream and clean up."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.close()
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import asyncio
|
||||
|
||||
from ._async._ws import ASGIWebSocketExtension
|
||||
|
||||
|
||||
class ThreadASGIWebSocketExtension:
|
||||
"""Synchronous WebSocket extension wrapping an async ASGIWebSocketExtension.
|
||||
|
||||
Delegates all operations to the async extension on a background event loop,
|
||||
blocking the calling thread via concurrent.futures.Future.
|
||||
"""
|
||||
|
||||
def __init__(self, async_ext: ASGIWebSocketExtension, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._async_ext = async_ext
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._async_ext.closed
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Block until the next message arrives from the ASGI WebSocket app."""
|
||||
future: Future[str | bytes | None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
result = await self._async_ext.next_payload()
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
return future.result()
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message to the ASGI WebSocket app."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.send_payload(buf)
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the WebSocket and clean up."""
|
||||
future: Future[None] = Future()
|
||||
|
||||
async def _do() -> None:
|
||||
try:
|
||||
await self._async_ext.close()
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(_do()))
|
||||
future.result()
|
||||
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ..._constant import DEFAULT_POOLBLOCK
|
||||
from ...adapters import HTTPAdapter
|
||||
from ...exceptions import RequestException
|
||||
from ...packages.urllib3.connection import HTTPConnection
|
||||
from ...packages.urllib3.connectionpool import HTTPConnectionPool
|
||||
from ...packages.urllib3.contrib.webextensions import ServerSideEventExtensionFromHTTP, WebSocketExtensionFromHTTP
|
||||
from ...packages.urllib3.poolmanager import PoolManager
|
||||
from ...typing import CacheLayerAltSvcType
|
||||
from ...utils import select_proxy
|
||||
|
||||
|
||||
class UnixServerSideEventExtensionFromHTTP(ServerSideEventExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"psse": "http+unix"}[scheme]
|
||||
|
||||
|
||||
if WebSocketExtensionFromHTTP is not None:
|
||||
|
||||
class UnixWebSocketExtensionFromHTTP(WebSocketExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http+unix"}[scheme]
|
||||
|
||||
|
||||
class UnixHTTPConnection(HTTPConnection):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.host: str = unquote(self.host)
|
||||
self.socket_path = self.host
|
||||
self.host = self.socket_path.split("/")[-1]
|
||||
|
||||
def connect(self):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
sock.connect(self.socket_path)
|
||||
self.sock = sock
|
||||
self._post_conn()
|
||||
|
||||
|
||||
class UnixHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = UnixHTTPConnection
|
||||
|
||||
|
||||
class UnixAdapter(HTTPAdapter):
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
quic_cache_layer: CacheLayerAltSvcType | None = None,
|
||||
**pool_kwargs: typing.Any,
|
||||
):
|
||||
self._pool_connections = connections
|
||||
self._pool_maxsize = maxsize
|
||||
self._pool_block = block
|
||||
self._quic_cache_layer = quic_cache_layer
|
||||
|
||||
self.poolmanager = PoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
preemptive_quic_cache=quic_cache_layer,
|
||||
**pool_kwargs,
|
||||
)
|
||||
self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
|
||||
self.poolmanager.pool_classes_by_scheme = {
|
||||
"http+unix": UnixHTTPConnectionPool,
|
||||
}
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
proxy = select_proxy(url, proxies)
|
||||
|
||||
if proxy:
|
||||
raise RequestException("unix socket cannot be associated with proxies")
|
||||
|
||||
return self.poolmanager.connection_from_url(url)
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
return request.path_url
|
||||
|
||||
def close(self):
|
||||
self.poolmanager.clear()
|
||||
|
||||
|
||||
__all__ = ("UnixAdapter",)
|
||||
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ...._constant import DEFAULT_POOLBLOCK
|
||||
from ....adapters import AsyncHTTPAdapter
|
||||
from ....exceptions import RequestException
|
||||
from ....packages.urllib3._async.connection import AsyncHTTPConnection
|
||||
from ....packages.urllib3._async.connectionpool import AsyncHTTPConnectionPool
|
||||
from ....packages.urllib3._async.poolmanager import AsyncPoolManager
|
||||
from ....packages.urllib3.contrib.ssa import AsyncSocket
|
||||
from ....packages.urllib3.contrib.webextensions._async import (
|
||||
AsyncServerSideEventExtensionFromHTTP,
|
||||
AsyncWebSocketExtensionFromHTTP,
|
||||
)
|
||||
from ....typing import CacheLayerAltSvcType
|
||||
from ....utils import select_proxy
|
||||
|
||||
|
||||
class AsyncUnixServerSideEventExtensionFromHTTP(AsyncServerSideEventExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"psse": "http+unix"}[scheme]
|
||||
|
||||
|
||||
if AsyncWebSocketExtensionFromHTTP is not None:
|
||||
|
||||
class AsyncUnixWebSocketExtensionFromHTTP(AsyncWebSocketExtensionFromHTTP):
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "unix"
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http+unix"}[scheme]
|
||||
|
||||
|
||||
class AsyncUnixHTTPConnection(AsyncHTTPConnection):
|
||||
def __init__(self, host, **kwargs):
|
||||
super().__init__(host, **kwargs)
|
||||
self.host: str = unquote(self.host)
|
||||
self.socket_path = self.host
|
||||
self.host = self.socket_path.split("/")[-1]
|
||||
|
||||
async def connect(self):
|
||||
sock = AsyncSocket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(self.timeout)
|
||||
await sock.connect(self.socket_path)
|
||||
self.sock = sock
|
||||
await self._post_conn()
|
||||
|
||||
|
||||
class AsyncUnixHTTPConnectionPool(AsyncHTTPConnectionPool):
|
||||
ConnectionCls = AsyncUnixHTTPConnection
|
||||
|
||||
|
||||
class AsyncUnixAdapter(AsyncHTTPAdapter):
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
quic_cache_layer: CacheLayerAltSvcType | None = None,
|
||||
**pool_kwargs: typing.Any,
|
||||
):
|
||||
self._pool_connections = connections
|
||||
self._pool_maxsize = maxsize
|
||||
self._pool_block = block
|
||||
self._quic_cache_layer = quic_cache_layer
|
||||
|
||||
self.poolmanager = AsyncPoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
preemptive_quic_cache=quic_cache_layer,
|
||||
**pool_kwargs,
|
||||
)
|
||||
self.poolmanager.key_fn_by_scheme["http+unix"] = self.poolmanager.key_fn_by_scheme["http"]
|
||||
self.poolmanager.pool_classes_by_scheme = {
|
||||
"http+unix": AsyncUnixHTTPConnectionPool,
|
||||
}
|
||||
|
||||
def get_connection(self, url, proxies=None):
|
||||
proxy = select_proxy(url, proxies)
|
||||
|
||||
if proxy:
|
||||
raise RequestException("unix socket cannot be associated with proxies")
|
||||
|
||||
return self.poolmanager.connection_from_url(url)
|
||||
|
||||
def request_url(self, request, proxies):
|
||||
return request.path_url
|
||||
|
||||
|
||||
__all__ = ("AsyncUnixAdapter",)
|
||||
227
.venv/lib/python3.9/site-packages/niquests/help.py
Normal file
227
.venv/lib/python3.9/site-packages/niquests/help.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Module containing bug report helper(s)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import platform
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None # type: ignore
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from json import JSONDecodeError
|
||||
|
||||
import charset_normalizer
|
||||
import h11
|
||||
|
||||
try:
|
||||
import idna # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
idna = None # type: ignore[assignment]
|
||||
|
||||
import jh2 # type: ignore
|
||||
import wassima
|
||||
|
||||
from . import HTTPError, RequestException, Session
|
||||
from . import __version__ as niquests_version
|
||||
from ._compat import HAS_LEGACY_URLLIB3
|
||||
|
||||
if HAS_LEGACY_URLLIB3 is True:
|
||||
import urllib3_future as urllib3
|
||||
|
||||
try:
|
||||
from urllib3 import __version__ as __legacy_urllib3_version__
|
||||
except (ImportError, AttributeError):
|
||||
__legacy_urllib3_version__ = None # type: ignore[assignment]
|
||||
else:
|
||||
import urllib3 # type: ignore[no-redef]
|
||||
|
||||
__legacy_urllib3_version__ = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import qh3 # type: ignore
|
||||
except ImportError:
|
||||
qh3 = None # type: ignore
|
||||
|
||||
try:
|
||||
import certifi # type: ignore
|
||||
except ImportError:
|
||||
certifi = None # type: ignore
|
||||
|
||||
try:
|
||||
from .extensions.revocation._ocsp import verify as ocsp_verify
|
||||
except ImportError:
|
||||
ocsp_verify = None # type: ignore
|
||||
|
||||
try:
|
||||
import wsproto # type: ignore[import-not-found]
|
||||
except ImportError:
|
||||
wsproto = None # type: ignore
|
||||
|
||||
|
||||
_IS_GIL_DISABLED: bool = hasattr(sys, "_is_gil_enabled") and sys._is_gil_enabled() is False
|
||||
|
||||
|
||||
def _implementation():
|
||||
"""Return a dict with the Python implementation and version.
|
||||
|
||||
Provide both the name and the version of the Python implementation
|
||||
currently running. For example, on CPython 3.10.3 it will return
|
||||
{'name': 'CPython', 'version': '3.10.3'}.
|
||||
|
||||
This function works best on CPython and PyPy: in particular, it probably
|
||||
doesn't work for Jython or IronPython. Future investigation should be done
|
||||
to work out the correct shape of the code for those platforms.
|
||||
"""
|
||||
implementation = platform.python_implementation()
|
||||
|
||||
if implementation == "CPython":
|
||||
implementation_version = platform.python_version()
|
||||
elif implementation == "PyPy":
|
||||
implementation_version = (
|
||||
f"{sys.pypy_version_info.major}" # type: ignore[attr-defined]
|
||||
f".{sys.pypy_version_info.minor}" # type: ignore[attr-defined]
|
||||
f".{sys.pypy_version_info.micro}" # type: ignore[attr-defined]
|
||||
)
|
||||
if sys.pypy_version_info.releaselevel != "final": # type: ignore[attr-defined]
|
||||
implementation_version = "".join(
|
||||
[implementation_version, sys.pypy_version_info.releaselevel] # type: ignore[attr-defined]
|
||||
)
|
||||
elif implementation == "Jython":
|
||||
implementation_version = platform.python_version() # Complete Guess
|
||||
elif implementation == "IronPython":
|
||||
implementation_version = platform.python_version() # Complete Guess
|
||||
else:
|
||||
implementation_version = "Unknown"
|
||||
|
||||
return {"name": implementation, "version": implementation_version}
|
||||
|
||||
|
||||
def info():
|
||||
"""Generate information for a bug report."""
|
||||
try:
|
||||
platform_info = {
|
||||
"system": platform.system(),
|
||||
"release": platform.release(),
|
||||
}
|
||||
except OSError:
|
||||
platform_info = {
|
||||
"system": "Unknown",
|
||||
"release": "Unknown",
|
||||
}
|
||||
|
||||
implementation_info = _implementation()
|
||||
urllib3_info = {
|
||||
"version": urllib3.__version__,
|
||||
"cohabitation_version": __legacy_urllib3_version__,
|
||||
}
|
||||
|
||||
charset_normalizer_info = {"version": charset_normalizer.__version__}
|
||||
|
||||
idna_info = {
|
||||
"version": getattr(idna, "__version__", "N/A"),
|
||||
}
|
||||
|
||||
if ssl is not None:
|
||||
system_ssl = ssl.OPENSSL_VERSION_NUMBER
|
||||
|
||||
system_ssl_info = {
|
||||
"version": f"{system_ssl:x}" if system_ssl is not None else "N/A",
|
||||
"name": ssl.OPENSSL_VERSION,
|
||||
}
|
||||
else:
|
||||
system_ssl_info = {"version": "N/A", "name": "N/A"}
|
||||
|
||||
return {
|
||||
"platform": platform_info,
|
||||
"implementation": implementation_info,
|
||||
"system_ssl": system_ssl_info,
|
||||
"gil": not _IS_GIL_DISABLED,
|
||||
"urllib3.future": urllib3_info,
|
||||
"charset_normalizer": charset_normalizer_info,
|
||||
"idna": idna_info,
|
||||
"niquests": {
|
||||
"version": niquests_version,
|
||||
},
|
||||
"http3": {
|
||||
"enabled": qh3 is not None,
|
||||
"qh3": qh3.__version__ if qh3 is not None else None,
|
||||
},
|
||||
"http2": {
|
||||
"jh2": jh2.__version__,
|
||||
},
|
||||
"http1": {
|
||||
"h11": h11.__version__,
|
||||
},
|
||||
"wassima": {
|
||||
"version": wassima.__version__,
|
||||
},
|
||||
"ocsp": {"enabled": ocsp_verify is not None},
|
||||
"websocket": {
|
||||
"enabled": wsproto is not None,
|
||||
"wsproto": wsproto.__version__ if wsproto is not None else None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
pypi_session = Session()
|
||||
|
||||
|
||||
def check_update(package_name: str, actual_version: str) -> None:
|
||||
"""
|
||||
Small and concise utility to check for updates.
|
||||
"""
|
||||
try:
|
||||
response = pypi_session.get(f"https://pypi.org/pypi/{package_name}/json")
|
||||
package_info = response.raise_for_status().json()
|
||||
|
||||
if isinstance(package_info, dict) and "info" in package_info and "version" in package_info["info"]:
|
||||
if package_info["info"]["version"] != actual_version:
|
||||
warnings.warn(
|
||||
f"You are using {package_name} {actual_version} and "
|
||||
f"PyPI yield version ({package_info['info']['version']}) as the stable one. "
|
||||
"We invite you to install this version as soon as possible. "
|
||||
f"Run `python -m pip install {package_name} -U`.",
|
||||
UserWarning,
|
||||
)
|
||||
except (RequestException, JSONDecodeError, HTTPError):
|
||||
pass
|
||||
|
||||
|
||||
PACKAGE_TO_CHECK_FOR_UPGRADE = {
|
||||
"niquests": niquests_version,
|
||||
"urllib3-future": urllib3.__version__,
|
||||
"qh3": qh3.__version__ if qh3 is not None else None,
|
||||
"jh2": jh2.__version__,
|
||||
"h11": h11.__version__,
|
||||
"charset-normalizer": charset_normalizer.__version__,
|
||||
"wassima": wassima.__version__,
|
||||
"idna": idna.__version__ if idna is not None else None,
|
||||
"wsproto": wsproto.__version__ if wsproto is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Pretty-print the bug information as JSON."""
|
||||
for package, actual_version in PACKAGE_TO_CHECK_FOR_UPGRADE.items():
|
||||
if actual_version is None:
|
||||
continue
|
||||
check_update(package, actual_version)
|
||||
|
||||
if __legacy_urllib3_version__ is not None:
|
||||
warnings.warn(
|
||||
"urllib3-future is installed alongside (legacy) urllib3. This may cause compatibility issues. "
|
||||
"Some (Requests) 3rd parties may be bound to urllib3, therefor the plugins may wrongfully invoke "
|
||||
"urllib3 (legacy) instead of urllib3-future. To remediate this, run "
|
||||
"`python -m pip uninstall -y urllib3 urllib3-future`, then run `python -m pip install urllib3-future`.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
print(json.dumps(info(), sort_keys=True, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
436
.venv/lib/python3.9/site-packages/niquests/hooks.py
Normal file
436
.venv/lib/python3.9/site-packages/niquests/hooks.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
requests.hooks
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
This module provides the capabilities for the Requests hooks system.
|
||||
|
||||
Available hooks:
|
||||
|
||||
``pre_request``:
|
||||
The prepared request just got built. You may alter it prior to be sent through HTTP.
|
||||
``pre_send``:
|
||||
The prepared request got his ConnectionInfo injected.
|
||||
This event is triggered just after picking a live connection from the pool.
|
||||
``on_upload``:
|
||||
Permit to monitor the upload progress of passed body.
|
||||
This event is triggered each time a block of data is transmitted to the remote peer.
|
||||
Use this hook carefully as it may impact the overall performance.
|
||||
``response``:
|
||||
The response generated from a Request.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from ._compat import iscoroutinefunction
|
||||
from .typing import (
|
||||
_HV,
|
||||
AsyncHookCallableType,
|
||||
AsyncHookType,
|
||||
HookCallableType,
|
||||
HookType,
|
||||
)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .models import PreparedRequest, Response
|
||||
|
||||
HOOKS = [
|
||||
"pre_request",
|
||||
"pre_send",
|
||||
"on_upload",
|
||||
"early_response",
|
||||
"response",
|
||||
]
|
||||
|
||||
|
||||
def default_hooks() -> HookType[_HV]:
|
||||
return {event: [] for event in HOOKS}
|
||||
|
||||
|
||||
def dispatch_hook(key: str, hooks: HookType[_HV] | None, hook_data: _HV, **kwargs: typing.Any) -> _HV:
|
||||
"""Dispatches a hook dictionary on a given piece of data."""
|
||||
if hooks is None:
|
||||
return hook_data
|
||||
|
||||
callables: list[HookCallableType[_HV]] | None = hooks.get(key) # type: ignore[assignment]
|
||||
|
||||
if callables:
|
||||
if callable(callables):
|
||||
callables = [callables]
|
||||
for hook in callables:
|
||||
try:
|
||||
_hook_data = hook(hook_data, **kwargs)
|
||||
except TypeError:
|
||||
_hook_data = hook(hook_data)
|
||||
if _hook_data is not None:
|
||||
hook_data = _hook_data
|
||||
|
||||
return hook_data
|
||||
|
||||
|
||||
async def async_dispatch_hook(key: str, hooks: AsyncHookType[_HV] | None, hook_data: _HV, **kwargs: typing.Any) -> _HV:
|
||||
"""Dispatches a hook dictionary on a given piece of data asynchronously."""
|
||||
if hooks is None:
|
||||
return hook_data
|
||||
|
||||
callables: list[HookCallableType[_HV] | AsyncHookCallableType[_HV]] | None = hooks.get(key)
|
||||
|
||||
if callables:
|
||||
if callable(callables):
|
||||
callables = [callables]
|
||||
for hook in callables:
|
||||
if iscoroutinefunction(hook):
|
||||
try:
|
||||
_hook_data = await hook(hook_data, **kwargs)
|
||||
except TypeError:
|
||||
_hook_data = await hook(hook_data)
|
||||
else:
|
||||
try:
|
||||
_hook_data = hook(hook_data, **kwargs)
|
||||
except TypeError:
|
||||
_hook_data = hook(hook_data)
|
||||
|
||||
if _hook_data is not None:
|
||||
hook_data = _hook_data
|
||||
|
||||
return hook_data
|
||||
|
||||
|
||||
class _BaseLifeCycleHook(
|
||||
typing.MutableMapping[str, typing.List[typing.Union[HookCallableType, AsyncHookCallableType]]], typing.Generic[_HV]
|
||||
):
|
||||
def __init__(self) -> None:
|
||||
self._store: MutableMapping[str, list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]] = {
|
||||
"pre_request": [],
|
||||
"pre_send": [],
|
||||
"on_upload": [],
|
||||
"early_response": [],
|
||||
"response": [],
|
||||
}
|
||||
|
||||
def __setitem__(self, key: str | bytes, value: list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]) -> None:
|
||||
raise NotImplementedError("LifeCycleHook is Read Only")
|
||||
|
||||
def __getitem__(self, key: str) -> list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]:
|
||||
return self._store[key]
|
||||
|
||||
def get(self, key: str) -> list[HookCallableType[_HV] | AsyncHookCallableType[_HV]]: # type: ignore[override]
|
||||
return self[key]
|
||||
|
||||
def __add__(self, other) -> _BaseLifeCycleHook:
|
||||
if not isinstance(other, _BaseLifeCycleHook):
|
||||
raise TypeError
|
||||
|
||||
tmp_store = {}
|
||||
combined_hooks: _BaseLifeCycleHook[_HV] = _BaseLifeCycleHook()
|
||||
|
||||
for h, fns in self._store.items():
|
||||
tmp_store[h] = fns
|
||||
tmp_store[h] += other._store[h]
|
||||
|
||||
combined_hooks._store = tmp_store
|
||||
|
||||
return combined_hooks
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._store
|
||||
|
||||
def items(self):
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def __delitem__(self, key):
|
||||
raise NotImplementedError("LifeCycleHook is Read Only")
|
||||
|
||||
def __len__(self):
|
||||
return len(self._store)
|
||||
|
||||
|
||||
class LifeCycleHook(_BaseLifeCycleHook[_HV]):
|
||||
"""
|
||||
A sync-only middleware to be used in your request/response lifecycles.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._store.update(
|
||||
{
|
||||
"pre_request": [self.pre_request], # type: ignore[list-item]
|
||||
"pre_send": [self.pre_send], # type: ignore[list-item]
|
||||
"on_upload": [self.on_upload], # type: ignore[list-item]
|
||||
"early_response": [self.early_response], # type: ignore[list-item]
|
||||
"response": [self.response], # type: ignore[list-item]
|
||||
}
|
||||
)
|
||||
|
||||
def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""The prepared request just got built. You may alter it prior to be sent through HTTP."""
|
||||
return None
|
||||
|
||||
def pre_send(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
|
||||
"""The prepared request got his ConnectionInfo injected. This event is triggered just
|
||||
after picking a live connection from the pool. You may not alter the prepared request."""
|
||||
return None
|
||||
|
||||
def on_upload(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
|
||||
"""Permit to monitor the upload progress of passed body. This event is triggered each time
|
||||
a block of data is transmitted to the remote peer. Use this hook carefully as
|
||||
it may impact the overall performance. You may not alter the prepared request."""
|
||||
return None
|
||||
|
||||
def early_response(self, response: Response, **kwargs: typing.Any) -> None:
|
||||
"""An early response caught before receiving the final Response for a given Request.
|
||||
Like but not limited to 103 Early Hints."""
|
||||
return None
|
||||
|
||||
def response(self, response: Response, **kwargs: typing.Any) -> Response | None:
|
||||
"""The response generated from a Request. You may alter the response at will."""
|
||||
return None
|
||||
|
||||
|
||||
class AsyncLifeCycleHook(_BaseLifeCycleHook[_HV]):
|
||||
"""
|
||||
An async-only middleware to be used in your request/response lifecycles.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._store.update(
|
||||
{
|
||||
"pre_request": [self.pre_request], # type: ignore[list-item]
|
||||
"pre_send": [self.pre_send], # type: ignore[list-item]
|
||||
"on_upload": [self.on_upload], # type: ignore[list-item]
|
||||
"early_response": [self.early_response], # type: ignore[list-item]
|
||||
"response": [self.response], # type: ignore[list-item]
|
||||
}
|
||||
)
|
||||
|
||||
async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""The prepared request just got built. You may alter it prior to be sent through HTTP."""
|
||||
return None
|
||||
|
||||
async def pre_send(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
|
||||
"""The prepared request got his ConnectionInfo injected. This event is triggered just
|
||||
after picking a live connection from the pool. You may not alter the prepared request."""
|
||||
return None
|
||||
|
||||
async def on_upload(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> None:
|
||||
"""Permit to monitor the upload progress of passed body. This event is triggered each time
|
||||
a block of data is transmitted to the remote peer. Use this hook carefully as
|
||||
it may impact the overall performance. You may not alter the prepared request."""
|
||||
return None
|
||||
|
||||
async def early_response(self, response: Response, **kwargs: typing.Any) -> None:
|
||||
"""An early response caught before receiving the final Response for a given Request.
|
||||
Like but not limited to 103 Early Hints."""
|
||||
return None
|
||||
|
||||
async def response(self, response: Response, **kwargs: typing.Any) -> Response | None:
|
||||
"""The response generated from a Request. You may alter the response at will."""
|
||||
return None
|
||||
|
||||
|
||||
class _LeakyBucketMixin:
|
||||
"""Shared leaky bucket algorithm logic."""
|
||||
|
||||
rate: float
|
||||
interval: float
|
||||
last_request: float | None
|
||||
|
||||
def _init_leaky_bucket(self, rate: float) -> None:
|
||||
self.rate = rate
|
||||
self.interval = 1.0 / rate
|
||||
self.last_request = None
|
||||
|
||||
def _compute_wait(self) -> float:
|
||||
"""Compute wait time and update state. Returns wait time (may be <= 0)."""
|
||||
now = time.monotonic()
|
||||
if self.last_request is not None:
|
||||
elapsed = now - self.last_request
|
||||
wait_time = self.interval - elapsed
|
||||
else:
|
||||
wait_time = 0.0
|
||||
return wait_time
|
||||
|
||||
def _record_request(self) -> None:
|
||||
"""Record that a request was made."""
|
||||
self.last_request = time.monotonic()
|
||||
|
||||
|
||||
class _TokenBucketMixin:
|
||||
"""Shared token bucket algorithm logic."""
|
||||
|
||||
rate: float
|
||||
capacity: float
|
||||
tokens: float
|
||||
last_update: float
|
||||
|
||||
def _init_token_bucket(self, rate: float, capacity: float | None) -> None:
|
||||
self.rate = rate
|
||||
self.capacity = capacity if capacity is not None else rate
|
||||
self.tokens = self.capacity
|
||||
self.last_update = time.monotonic()
|
||||
|
||||
def _acquire_token(self) -> float | None:
|
||||
"""Replenish tokens and try to acquire one. Returns wait time if needed, None otherwise."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_update
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
self.last_update = now
|
||||
|
||||
if self.tokens >= 1.0:
|
||||
self.tokens -= 1.0
|
||||
return None
|
||||
else:
|
||||
# Don't update last_update here; let _post_wait handle it
|
||||
wait_time = (1.0 - self.tokens) / self.rate
|
||||
return wait_time
|
||||
|
||||
def _post_wait(self) -> None:
|
||||
"""Called after waiting to consume the token."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self.last_update
|
||||
# Replenish tokens accumulated during the wait
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
self.last_update = now
|
||||
# Now consume the token
|
||||
self.tokens -= 1.0
|
||||
|
||||
|
||||
class LeakyBucketLimiter(_LeakyBucketMixin, LifeCycleHook):
|
||||
"""Rate limiter using the leaky bucket algorithm.
|
||||
|
||||
Requests "leak" out at a constant rate. When a request arrives, it waits
|
||||
until enough time has passed since the last request to maintain the rate.
|
||||
|
||||
Usage::
|
||||
|
||||
limiter = LeakyBucketLimiter(rate=10.0) # 10 requests per second
|
||||
with niquests.Session(hooks=limiter) as session:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float = 10.0) -> None:
|
||||
"""Initialize the leaky bucket limiter.
|
||||
|
||||
Args:
|
||||
rate: Maximum requests per second
|
||||
"""
|
||||
super().__init__()
|
||||
self._init_leaky_bucket(rate)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""Wait if needed to maintain the rate limit."""
|
||||
with self._lock:
|
||||
wait_time = self._compute_wait()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
self._record_request()
|
||||
return None
|
||||
|
||||
|
||||
class AsyncLeakyBucketLimiter(_LeakyBucketMixin, AsyncLifeCycleHook):
|
||||
"""Rate limiter using the leaky bucket algorithm.
|
||||
|
||||
Requests "leak" out at a constant rate. When a request arrives, it waits
|
||||
until enough time has passed since the last request to maintain the rate.
|
||||
|
||||
Usage::
|
||||
|
||||
limiter = AsyncLeakyBucketLimiter(rate=10.0) # 10 requests per second
|
||||
async with niquests.AsyncSession(hooks=limiter) as session:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float = 10.0) -> None:
|
||||
"""Initialize the leaky bucket limiter.
|
||||
|
||||
Args:
|
||||
rate: Maximum requests per second
|
||||
"""
|
||||
super().__init__()
|
||||
self._init_leaky_bucket(rate)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""Wait if needed to maintain the rate limit."""
|
||||
async with self._lock:
|
||||
wait_time = self._compute_wait()
|
||||
if wait_time > 0:
|
||||
await asyncio.sleep(wait_time)
|
||||
self._record_request()
|
||||
return None
|
||||
|
||||
|
||||
class TokenBucketLimiter(_TokenBucketMixin, LifeCycleHook):
|
||||
"""Rate limiter using the token bucket algorithm.
|
||||
|
||||
Tokens are added to a bucket at a constant rate up to a maximum capacity.
|
||||
Each request consumes one token. Allows bursts up to the bucket capacity.
|
||||
|
||||
Usage::
|
||||
|
||||
limiter = TokenBucketLimiter(rate=10.0, capacity=50.0) # 10/s, burst of 50
|
||||
with niquests.Session(hooks=limiter) as session:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float = 10.0, capacity: float | None = None) -> None:
|
||||
"""Initialize the token bucket limiter.
|
||||
|
||||
Args:
|
||||
rate: Token replenishment rate (tokens per second)
|
||||
capacity: Maximum bucket capacity (defaults to rate, allowing 1 second burst)
|
||||
"""
|
||||
super().__init__()
|
||||
self._init_token_bucket(rate, capacity)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""Wait until a token is available, then consume it."""
|
||||
with self._lock:
|
||||
wait_time = self._acquire_token()
|
||||
if wait_time is not None:
|
||||
time.sleep(wait_time)
|
||||
self._post_wait()
|
||||
return None
|
||||
|
||||
|
||||
class AsyncTokenBucketLimiter(_TokenBucketMixin, AsyncLifeCycleHook):
|
||||
"""Rate limiter using the token bucket algorithm.
|
||||
|
||||
Tokens are added to a bucket at a constant rate up to a maximum capacity.
|
||||
Each request consumes one token. Allows bursts up to the bucket capacity.
|
||||
|
||||
Usage::
|
||||
|
||||
limiter = AsyncTokenBucketLimiter(rate=10.0, capacity=50.0) # 10/s, burst of 50
|
||||
async with niquests.AsyncSession(hooks=limiter) as session:
|
||||
...
|
||||
"""
|
||||
|
||||
def __init__(self, rate: float = 10.0, capacity: float | None = None) -> None:
|
||||
"""Initialize the token bucket limiter.
|
||||
|
||||
Args:
|
||||
rate: Token replenishment rate (tokens per second)
|
||||
capacity: Maximum bucket capacity (defaults to rate, allowing 1 second burst)
|
||||
"""
|
||||
super().__init__()
|
||||
self._init_token_bucket(rate, capacity)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def pre_request(self, prepared_request: PreparedRequest, **kwargs: typing.Any) -> PreparedRequest | None:
|
||||
"""Wait until a token is available, then consume it."""
|
||||
async with self._lock:
|
||||
wait_time = self._acquire_token()
|
||||
if wait_time is not None:
|
||||
await asyncio.sleep(wait_time)
|
||||
self._post_wait()
|
||||
return None
|
||||
2006
.venv/lib/python3.9/site-packages/niquests/models.py
Normal file
2006
.venv/lib/python3.9/site-packages/niquests/models.py
Normal file
File diff suppressed because it is too large
Load Diff
122
.venv/lib/python3.9/site-packages/niquests/packages.py
Normal file
122
.venv/lib/python3.9/site-packages/niquests/packages.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import importlib.abc
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from ._compat import HAS_LEGACY_URLLIB3
|
||||
|
||||
# just to enable smooth type-completion!
|
||||
if typing.TYPE_CHECKING:
|
||||
import charset_normalizer as chardet
|
||||
import urllib3
|
||||
|
||||
charset_normalizer = chardet
|
||||
|
||||
import idna # type: ignore[import-not-found]
|
||||
|
||||
# Mapping of aliased package prefixes:
|
||||
# "niquests.packages.<alias>." to "<real-package>."
|
||||
# Populated by the loop below, consumed by the import hook.
|
||||
_ALIAS_TO_REAL: dict[str, str] = {}
|
||||
|
||||
# This code exists for backwards compatibility reasons.
|
||||
# I don't like it either. Just look the other way. :)
|
||||
for package in (
|
||||
"urllib3",
|
||||
"charset_normalizer",
|
||||
"idna",
|
||||
"chardet",
|
||||
):
|
||||
to_be_imported: str = package
|
||||
|
||||
if package == "chardet":
|
||||
to_be_imported = "charset_normalizer"
|
||||
elif package == "urllib3" and HAS_LEGACY_URLLIB3:
|
||||
to_be_imported = "urllib3_future"
|
||||
|
||||
try:
|
||||
locals()[package] = __import__(to_be_imported)
|
||||
except ImportError:
|
||||
continue # idna could be missing. not required!
|
||||
|
||||
# Determine the alias prefix (what niquests code imports)
|
||||
# and the real prefix (the actual installed package).
|
||||
if package == "chardet":
|
||||
alias_prefix = "niquests.packages.chardet."
|
||||
real_prefix = "charset_normalizer."
|
||||
alias_root = "niquests.packages.chardet"
|
||||
real_root = "charset_normalizer"
|
||||
elif package == "urllib3" and HAS_LEGACY_URLLIB3:
|
||||
alias_prefix = "niquests.packages.urllib3."
|
||||
real_prefix = "urllib3_future."
|
||||
alias_root = "niquests.packages.urllib3"
|
||||
real_root = "urllib3_future"
|
||||
else:
|
||||
alias_prefix = f"niquests.packages.{package}."
|
||||
real_prefix = f"{package}."
|
||||
alias_root = f"niquests.packages.{package}"
|
||||
real_root = package
|
||||
|
||||
_ALIAS_TO_REAL[alias_prefix] = real_prefix
|
||||
_ALIAS_TO_REAL[alias_root] = real_root
|
||||
|
||||
# This traversal is apparently necessary such that the identities are
|
||||
# preserved (requests.packages.urllib3.* is urllib3.*)
|
||||
for mod in list(sys.modules):
|
||||
if mod == to_be_imported or mod.startswith(f"{to_be_imported}."):
|
||||
inner_mod = mod
|
||||
|
||||
if HAS_LEGACY_URLLIB3 and inner_mod == "urllib3_future" or inner_mod.startswith("urllib3_future."):
|
||||
inner_mod = inner_mod.replace("urllib3_future", "urllib3")
|
||||
elif inner_mod == "charset_normalizer":
|
||||
inner_mod = "chardet"
|
||||
|
||||
try:
|
||||
sys.modules[f"niquests.packages.{inner_mod}"] = sys.modules[mod]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
|
||||
class _NiquestsPackagesAliasImporter(importlib.abc.MetaPathFinder):
|
||||
"""Made to avoid duplicate due to lazy imports at urllib3-future side(...)"""
|
||||
|
||||
def find_spec(
|
||||
self,
|
||||
fullname: str,
|
||||
path: typing.Any = None,
|
||||
target: typing.Any = None,
|
||||
) -> importlib.machinery.ModuleSpec | None:
|
||||
if fullname in sys.modules:
|
||||
return None
|
||||
|
||||
real_name: str | None = None
|
||||
for alias, real in _ALIAS_TO_REAL.items():
|
||||
if fullname == alias or fullname.startswith(alias if alias.endswith(".") else alias + "."):
|
||||
real_name = real + fullname[len(alias) :]
|
||||
break
|
||||
|
||||
if real_name is None:
|
||||
return None
|
||||
|
||||
# Import the real module first, then point the alias at it.
|
||||
real_module = importlib.import_module(real_name)
|
||||
sys.modules[fullname] = real_module
|
||||
|
||||
# Return a spec that resolves to the cached module.
|
||||
return importlib.util.spec_from_loader(fullname, loader=None, origin=real_name) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# Insert at front so we intercept before the default PathFinder.
|
||||
sys.meta_path.insert(0, _NiquestsPackagesAliasImporter())
|
||||
|
||||
|
||||
__all__ = (
|
||||
"urllib3",
|
||||
"chardet",
|
||||
"charset_normalizer",
|
||||
"idna",
|
||||
)
|
||||
0
.venv/lib/python3.9/site-packages/niquests/py.typed
Normal file
0
.venv/lib/python3.9/site-packages/niquests/py.typed
Normal file
1956
.venv/lib/python3.9/site-packages/niquests/sessions.py
Normal file
1956
.venv/lib/python3.9/site-packages/niquests/sessions.py
Normal file
File diff suppressed because it is too large
Load Diff
126
.venv/lib/python3.9/site-packages/niquests/status_codes.py
Normal file
126
.venv/lib/python3.9/site-packages/niquests/status_codes.py
Normal file
@@ -0,0 +1,126 @@
|
||||
r"""
|
||||
The ``codes`` object defines a mapping from common names for HTTP statuses
|
||||
to their numerical codes, accessible either as attributes or as dictionary
|
||||
items.
|
||||
|
||||
Example::
|
||||
|
||||
>>> import niquests
|
||||
>>> niquests.codes['temporary_redirect']
|
||||
307
|
||||
>>> niquests.codes.teapot
|
||||
418
|
||||
>>> niquests.codes['\o/']
|
||||
200
|
||||
|
||||
Some codes have multiple names, and both upper- and lower-case versions of
|
||||
the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
|
||||
``codes.okay`` all correspond to the HTTP status code 200.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .structures import LookupDict
|
||||
|
||||
_codes = {
|
||||
# Informational.
|
||||
100: ("continue",),
|
||||
101: ("switching_protocols",),
|
||||
102: ("processing",),
|
||||
103: ("checkpoint",),
|
||||
122: ("uri_too_long", "request_uri_too_long"),
|
||||
200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", "✓"),
|
||||
201: ("created",),
|
||||
202: ("accepted",),
|
||||
203: ("non_authoritative_info", "non_authoritative_information"),
|
||||
204: ("no_content",),
|
||||
205: ("reset_content", "reset"),
|
||||
206: ("partial_content", "partial"),
|
||||
207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
|
||||
208: ("already_reported",),
|
||||
226: ("im_used",),
|
||||
# Redirection.
|
||||
300: ("multiple_choices",),
|
||||
301: ("moved_permanently", "moved", "\\o-"),
|
||||
302: ("found",),
|
||||
303: ("see_other", "other"),
|
||||
304: ("not_modified",),
|
||||
305: ("use_proxy",),
|
||||
306: ("switch_proxy",),
|
||||
307: ("temporary_redirect", "temporary_moved", "temporary"),
|
||||
308: (
|
||||
"permanent_redirect",
|
||||
"resume_incomplete",
|
||||
"resume",
|
||||
), # "resume" and "resume_incomplete" to be removed in 3.0
|
||||
# Client Error.
|
||||
400: ("bad_request", "bad"),
|
||||
401: ("unauthorized",),
|
||||
402: ("payment_required", "payment"),
|
||||
403: ("forbidden",),
|
||||
404: ("not_found", "-o-"),
|
||||
405: ("method_not_allowed", "not_allowed"),
|
||||
406: ("not_acceptable",),
|
||||
407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
|
||||
408: ("request_timeout", "timeout"),
|
||||
409: ("conflict",),
|
||||
410: ("gone",),
|
||||
411: ("length_required",),
|
||||
412: ("precondition_failed", "precondition"),
|
||||
413: ("request_entity_too_large",),
|
||||
414: ("request_uri_too_large",),
|
||||
415: ("unsupported_media_type", "unsupported_media", "media_type"),
|
||||
416: (
|
||||
"requested_range_not_satisfiable",
|
||||
"requested_range",
|
||||
"range_not_satisfiable",
|
||||
),
|
||||
417: ("expectation_failed",),
|
||||
418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
|
||||
421: ("misdirected_request",),
|
||||
422: ("unprocessable_entity", "unprocessable"),
|
||||
423: ("locked",),
|
||||
424: ("failed_dependency", "dependency"),
|
||||
425: ("unordered_collection", "unordered", "too_early"),
|
||||
426: ("upgrade_required", "upgrade"),
|
||||
428: ("precondition_required", "precondition"),
|
||||
429: ("too_many_requests", "too_many"),
|
||||
431: ("header_fields_too_large", "fields_too_large"),
|
||||
444: ("no_response", "none"),
|
||||
449: ("retry_with", "retry"),
|
||||
450: ("blocked_by_windows_parental_controls", "parental_controls"),
|
||||
451: ("unavailable_for_legal_reasons", "legal_reasons"),
|
||||
499: ("client_closed_request",),
|
||||
# Server Error.
|
||||
500: ("internal_server_error", "server_error", "/o\\", "✗"),
|
||||
501: ("not_implemented",),
|
||||
502: ("bad_gateway",),
|
||||
503: ("service_unavailable", "unavailable"),
|
||||
504: ("gateway_timeout",),
|
||||
505: ("http_version_not_supported", "http_version"),
|
||||
506: ("variant_also_negotiates",),
|
||||
507: ("insufficient_storage",),
|
||||
509: ("bandwidth_limit_exceeded", "bandwidth"),
|
||||
510: ("not_extended",),
|
||||
511: ("network_authentication_required", "network_auth", "network_authentication"),
|
||||
}
|
||||
|
||||
codes = LookupDict(name="status_codes")
|
||||
|
||||
|
||||
def _init():
|
||||
for code, titles in _codes.items():
|
||||
for title in titles:
|
||||
setattr(codes, title, code)
|
||||
if not title.startswith(("\\", "/")):
|
||||
setattr(codes, title.upper(), code)
|
||||
|
||||
def doc(code):
|
||||
names = ", ".join(f"``{n}``" for n in _codes[code])
|
||||
return f"* {code}: {names}"
|
||||
|
||||
global __doc__
|
||||
__doc__ = __doc__ + "\n" + "\n".join(doc(code) for code in sorted(_codes)) if __doc__ is not None else None
|
||||
|
||||
|
||||
_init()
|
||||
283
.venv/lib/python3.9/site-packages/niquests/structures.py
Normal file
283
.venv/lib/python3.9/site-packages/niquests/structures.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
requests.structures
|
||||
~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Data structures that power Requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import typing
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
|
||||
try:
|
||||
from ._compat import HAS_LEGACY_URLLIB3
|
||||
|
||||
if not HAS_LEGACY_URLLIB3:
|
||||
from urllib3._collections import _lower_wrapper # type: ignore[attr-defined]
|
||||
else: # Defensive: tested in separate/isolated CI
|
||||
from urllib3_future._collections import (
|
||||
_lower_wrapper, # type: ignore[attr-defined]
|
||||
)
|
||||
except ImportError:
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
def _lower_wrapper(string: str) -> str:
|
||||
"""backport"""
|
||||
return string.lower()
|
||||
|
||||
|
||||
from .exceptions import InvalidHeader
|
||||
|
||||
|
||||
def _ensure_str_or_bytes(key: typing.Any, value: typing.Any) -> tuple[bytes | str, bytes | str]:
|
||||
if isinstance(key, (bytes, str)) and isinstance(value, (bytes, str)):
|
||||
return key, value
|
||||
if isinstance(
|
||||
value,
|
||||
(
|
||||
float,
|
||||
int,
|
||||
),
|
||||
):
|
||||
value = str(value)
|
||||
if isinstance(key, (bytes, str)) is False or (value is not None and isinstance(value, (bytes, str)) is False):
|
||||
raise InvalidHeader(f"Illegal header name or value {key}")
|
||||
return key, value
|
||||
|
||||
|
||||
_T = typing.TypeVar("_T")
|
||||
|
||||
|
||||
class CaseInsensitiveDict(MutableMapping):
|
||||
"""A case-insensitive ``dict``-like object.
|
||||
|
||||
Implements all methods and operations of
|
||||
``MutableMapping`` as well as dict's ``copy``. Also
|
||||
provides ``lower_items``.
|
||||
|
||||
All keys are expected to be strings. The structure remembers the
|
||||
case of the last key to be set, and ``iter(instance)``,
|
||||
``keys()``, ``items()``, ``iterkeys()``, and ``iteritems()``
|
||||
will contain case-sensitive keys. However, querying and contains
|
||||
testing is case insensitive::
|
||||
|
||||
cid = CaseInsensitiveDict()
|
||||
cid['Accept'] = 'application/json'
|
||||
cid['aCCEPT'] == 'application/json' # True
|
||||
list(cid) == ['Accept'] # True
|
||||
|
||||
For example, ``headers['content-encoding']`` will return the
|
||||
value of a ``'Content-Encoding'`` response header, regardless
|
||||
of how the header name was originally stored.
|
||||
|
||||
If the constructor, ``.update``, or equality comparison
|
||||
operations are given keys that have equal ``.lower()``s, the
|
||||
behavior is undefined.
|
||||
"""
|
||||
|
||||
def __init__(self, data=None, **kwargs) -> None:
|
||||
self._store: MutableMapping[bytes | str, tuple[bytes | str, ...]] = {}
|
||||
if data is None:
|
||||
data = {}
|
||||
|
||||
# given object is most likely to be urllib3.HTTPHeaderDict or follow a similar implementation that we can trust
|
||||
if hasattr(data, "getlist"):
|
||||
self._store = data._container.copy()
|
||||
elif isinstance(data, CaseInsensitiveDict):
|
||||
self._store = data._store.copy() # type: ignore[attr-defined]
|
||||
else: # otherwise, we must ensure given iterable contains type we can rely on
|
||||
if data or kwargs:
|
||||
if hasattr(data, "items"):
|
||||
self.update(data, **kwargs)
|
||||
else:
|
||||
self.update(
|
||||
{k: v for k, v in data},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __setitem__(self, key: str | bytes, value: str | bytes) -> None:
|
||||
# Use the lowercased key for lookups, but store the actual
|
||||
# key alongside the value.
|
||||
self._store[_lower_wrapper(key)] = _ensure_str_or_bytes(key, value)
|
||||
|
||||
def __getitem__(self, key) -> bytes | str:
|
||||
e = self._store[_lower_wrapper(key)]
|
||||
if len(e) == 2:
|
||||
return e[1]
|
||||
# this path should always be list[str] (if coming from urllib3.HTTPHeaderDict!)
|
||||
try:
|
||||
return ", ".join(e[1:]) if isinstance(e[1], str) else b", ".join(e[1:]) # type: ignore[arg-type]
|
||||
except TypeError: # worst case scenario...
|
||||
return ", ".join(v.decode() if isinstance(v, bytes) else v for v in e[1:])
|
||||
|
||||
@typing.overload # type: ignore[override]
|
||||
def get(self, key: str | bytes) -> str | bytes | None: ...
|
||||
|
||||
@typing.overload
|
||||
def get(self, key: str | bytes, default: str | bytes) -> str | bytes: ...
|
||||
|
||||
@typing.overload
|
||||
def get(self, key: str | bytes, default: _T) -> str | bytes | _T: ...
|
||||
|
||||
def get(self, key: str | bytes, default: str | bytes | _T | None = None) -> str | bytes | _T | None:
|
||||
return super().get(key, default=default)
|
||||
|
||||
def __delitem__(self, key) -> None:
|
||||
del self._store[_lower_wrapper(key)]
|
||||
|
||||
def __iter__(self) -> typing.Iterator[str | bytes]:
|
||||
for key_ci in self._store:
|
||||
yield self._store[key_ci][0]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
|
||||
def lower_items(self) -> typing.Iterator[tuple[bytes | str, bytes | str]]:
|
||||
"""Like iteritems(), but with all lowercase keys."""
|
||||
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())
|
||||
|
||||
def items(self):
|
||||
for k in self._store:
|
||||
t = self._store[k]
|
||||
if len(t) == 2:
|
||||
yield tuple(t)
|
||||
else: # this case happen due to copying "_container" from HTTPHeaderDict!
|
||||
try:
|
||||
yield t[0], ", ".join(t[1:]) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
yield (
|
||||
t[0],
|
||||
", ".join(v.decode() if isinstance(v, bytes) else v for v in t[1:]),
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if isinstance(other, Mapping):
|
||||
other = CaseInsensitiveDict(other)
|
||||
else:
|
||||
return NotImplemented
|
||||
# Compare insensitively
|
||||
return dict(self.lower_items()) == dict(other.lower_items())
|
||||
|
||||
# Copy is required
|
||||
def copy(self) -> CaseInsensitiveDict:
|
||||
return CaseInsensitiveDict(self)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(dict(self.items()))
|
||||
|
||||
def __contains__(self, item: str) -> bool: # type: ignore[override]
|
||||
return _lower_wrapper(item) in self._store
|
||||
|
||||
|
||||
class LookupDict(dict):
|
||||
"""Dictionary lookup object."""
|
||||
|
||||
def __init__(self, name=None) -> None:
|
||||
self.name: str | None = name
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return f"<lookup '{self.name}'>"
|
||||
|
||||
def __getitem__(self, key):
|
||||
# We allow fall-through here, so values default to None
|
||||
return self.__dict__.get(key, None)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self.__dict__.get(key, default)
|
||||
|
||||
|
||||
class SharableLimitedDict(typing.MutableMapping):
|
||||
def __init__(self, max_size: int | None) -> None:
|
||||
self._store: typing.MutableMapping[typing.Any, typing.Any] = {}
|
||||
self._max_size = max_size
|
||||
self._lock: threading.RLock | DummyLock = threading.RLock()
|
||||
|
||||
def __getstate__(self) -> dict[str, typing.Any]:
|
||||
return {"_store": self._store, "_max_size": self._max_size}
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._store = state["_store"]
|
||||
self._max_size = state["_max_size"]
|
||||
|
||||
def __delitem__(self, __key) -> None:
|
||||
with self._lock:
|
||||
del self._store[__key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._store)
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
with self._lock:
|
||||
return iter(self._store)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
with self._lock:
|
||||
if self._max_size and len(self._store) >= self._max_size:
|
||||
self._store.popitem()
|
||||
|
||||
self._store[key] = value
|
||||
|
||||
def __getitem__(self, item):
|
||||
with self._lock:
|
||||
return self._store[item]
|
||||
|
||||
|
||||
class QuicSharedCache(SharableLimitedDict):
|
||||
def __init__(self, max_size: int | None) -> None:
|
||||
super().__init__(max_size)
|
||||
self._exclusion_store: typing.MutableMapping[typing.Any, typing.Any] = {}
|
||||
|
||||
def add_domain(self, host: str, port: int | None = None, alt_port: int | None = None) -> None:
|
||||
if port is None:
|
||||
port = 443
|
||||
if alt_port is None:
|
||||
alt_port = port
|
||||
self[(host, port)] = (host, alt_port)
|
||||
|
||||
def exclude_domain(self, host: str, port: int | None = None, alt_port: int | None = None):
|
||||
if port is None:
|
||||
port = 443
|
||||
if alt_port is None:
|
||||
alt_port = port
|
||||
self._exclusion_store[(host, port)] = (host, alt_port)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
with self._lock:
|
||||
if key in self._exclusion_store:
|
||||
return
|
||||
|
||||
if self._max_size and len(self._store) >= self._max_size:
|
||||
self._store.popitem()
|
||||
|
||||
self._store[key] = value
|
||||
|
||||
|
||||
class AsyncQuicSharedCache(QuicSharedCache):
|
||||
def __init__(self, max_size: int | None) -> None:
|
||||
super().__init__(max_size)
|
||||
self._lock = DummyLock()
|
||||
|
||||
def __setstate__(self, state: dict[str, typing.Any]) -> None:
|
||||
self._lock = DummyLock()
|
||||
self._store = state["_store"]
|
||||
self._max_size = state["_max_size"]
|
||||
|
||||
|
||||
class DummyLock:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def acquire(self):
|
||||
pass
|
||||
|
||||
def release(self):
|
||||
pass
|
||||
206
.venv/lib/python3.9/site-packages/niquests/typing.py
Normal file
206
.venv/lib/python3.9/site-packages/niquests/typing.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import typing
|
||||
from http.cookiejar import CookieJar
|
||||
from os import PathLike
|
||||
|
||||
from ._vendor.kiss_headers import Headers
|
||||
from .auth import AsyncAuthBase, AuthBase
|
||||
from .packages.urllib3 import AsyncResolverDescription, ResolverDescription, Retry, Timeout
|
||||
from .packages.urllib3.contrib.resolver import BaseResolver
|
||||
from .packages.urllib3.contrib.resolver._async import AsyncBaseResolver
|
||||
from .packages.urllib3.fields import RequestField
|
||||
from .structures import CaseInsensitiveDict
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# tool like pyright in strict mode can't infer what is ".packages.urllib3"
|
||||
# so we circumvent it there[...]
|
||||
from urllib3 import AsyncResolverDescription, ResolverDescription, Retry, Timeout # type: ignore[no-redef]
|
||||
from urllib3.contrib.resolver import BaseResolver # type: ignore[no-redef]
|
||||
from urllib3.contrib.resolver._async import AsyncBaseResolver # type: ignore[no-redef]
|
||||
from urllib3.fields import RequestField # type: ignore[no-redef]
|
||||
|
||||
from .hooks import AsyncLifeCycleHook, LifeCycleHook
|
||||
from .models import PreparedRequest
|
||||
|
||||
#: (Restricted) list of http verb that we natively support and understand.
|
||||
HttpMethodType: typing.TypeAlias = str
|
||||
#: List of formats accepted for URL queries parameters. (e.g. /?param1=a¶m2=b)
|
||||
QueryParameterType: typing.TypeAlias = typing.Union[
|
||||
typing.List[typing.Tuple[str, typing.Union[str, typing.List[str], None]]],
|
||||
typing.Mapping[str, typing.Union[str, typing.List[str], None]],
|
||||
bytes,
|
||||
str,
|
||||
]
|
||||
BodyFormType: typing.TypeAlias = typing.Union[
|
||||
typing.List[typing.Tuple[str, str]],
|
||||
typing.Dict[str, typing.Union[typing.List[str], str]],
|
||||
]
|
||||
#: Accepted types for the payload in POST, PUT, and PATCH requests.
|
||||
BodyType: typing.TypeAlias = typing.Union[
|
||||
str,
|
||||
bytes,
|
||||
bytearray,
|
||||
typing.IO[bytes],
|
||||
typing.IO[str],
|
||||
BodyFormType,
|
||||
typing.Iterable[bytes],
|
||||
typing.Iterable[str],
|
||||
]
|
||||
AsyncBodyType: typing.TypeAlias = typing.Union[
|
||||
typing.AsyncIterable[bytes],
|
||||
typing.AsyncIterable[str],
|
||||
]
|
||||
#: HTTP Headers can be represented through three ways. 1) typical dict, 2) internal insensitive dict, and 3) list of tuple.
|
||||
HeadersType: typing.TypeAlias = typing.Union[
|
||||
typing.MutableMapping[typing.Union[str, bytes], typing.Union[str, bytes]],
|
||||
typing.MutableMapping[str, str],
|
||||
typing.MutableMapping[bytes, bytes],
|
||||
CaseInsensitiveDict,
|
||||
typing.List[typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]]],
|
||||
Headers,
|
||||
]
|
||||
#: We accept both typical mapping and stdlib CookieJar.
|
||||
CookiesType: typing.TypeAlias = typing.Union[
|
||||
typing.MutableMapping[str, str],
|
||||
CookieJar,
|
||||
]
|
||||
#: Either Yes/No, or CA bundle pem location. Or directly the raw bundle content itself.
|
||||
if sys.version_info >= (3, 9): # we can't subscribe PathLike until that version...
|
||||
# This one was found used directly within a Pydantic model
|
||||
# see https://github.com/jawah/niquests/issues/324
|
||||
TLSVerifyType: typing.TypeAlias = typing.Union[bool, str, bytes, PathLike[str]]
|
||||
else:
|
||||
TLSVerifyType: typing.TypeAlias = typing.Union[bool, str, bytes, "PathLike[str]"]
|
||||
#: Accept a pem certificate (concat cert, key) or an explicit tuple of cert, key pair with an optional password.
|
||||
TLSClientCertType: typing.TypeAlias = typing.Union[str, typing.Tuple[str, str], typing.Tuple[str, str, str]]
|
||||
#: All accepted ways to describe desired timeout.
|
||||
TimeoutType: typing.TypeAlias = typing.Union[
|
||||
int, # TotalTimeout
|
||||
float, # TotalTimeout
|
||||
typing.Tuple[typing.Union[int, float], typing.Union[int, float]], # note: TotalTimeout, ConnectTimeout
|
||||
typing.Tuple[
|
||||
typing.Union[int, float], typing.Union[int, float], typing.Union[int, float]
|
||||
], # note: TotalTimeout, ConnectTimeout, ReadTimeout
|
||||
Timeout,
|
||||
]
|
||||
#: Specify (BasicAuth) authentication by passing a tuple of user, and password.
|
||||
#: Can be a custom authentication mechanism that derive from AuthBase.
|
||||
HttpAuthenticationType: typing.TypeAlias = typing.Union[
|
||||
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
|
||||
str,
|
||||
AuthBase,
|
||||
typing.Callable[["PreparedRequest"], "PreparedRequest"],
|
||||
]
|
||||
AsyncHttpAuthenticationType: typing.TypeAlias = typing.Union[
|
||||
AsyncAuthBase,
|
||||
typing.Callable[["PreparedRequest"], typing.Awaitable["PreparedRequest"]],
|
||||
]
|
||||
#: Map for each protocol (http, https) associated proxy to be used.
|
||||
ProxyType: typing.TypeAlias = typing.Dict[str, str]
|
||||
|
||||
# cases:
|
||||
# 1) fn, fp
|
||||
# 2) fn, fp, ft
|
||||
# 3) fn, fp, ft, fh
|
||||
# OR
|
||||
# 4) fp
|
||||
BodyFileType: typing.TypeAlias = typing.Union[
|
||||
str,
|
||||
bytes,
|
||||
bytearray,
|
||||
typing.IO[str],
|
||||
typing.IO[bytes],
|
||||
]
|
||||
MultiPartFileType: typing.TypeAlias = typing.Tuple[
|
||||
str,
|
||||
typing.Union[
|
||||
BodyFileType,
|
||||
typing.Tuple[str, BodyFileType],
|
||||
typing.Tuple[str, BodyFileType, str],
|
||||
typing.Tuple[str, BodyFileType, str, HeadersType],
|
||||
],
|
||||
]
|
||||
MultiPartFilesType: typing.TypeAlias = typing.List[MultiPartFileType]
|
||||
#: files (multipart formdata) can be (also) passed as dict.
|
||||
MultiPartFilesAltType: typing.TypeAlias = typing.Dict[
|
||||
str,
|
||||
typing.Union[
|
||||
BodyFileType,
|
||||
typing.Tuple[str, BodyFileType],
|
||||
typing.Tuple[str, BodyFileType, str],
|
||||
typing.Tuple[str, BodyFileType, str, HeadersType],
|
||||
],
|
||||
]
|
||||
|
||||
FieldValueType: typing.TypeAlias = typing.Union[str, bytes]
|
||||
FieldTupleType: typing.TypeAlias = typing.Union[
|
||||
FieldValueType,
|
||||
typing.Tuple[str, FieldValueType],
|
||||
typing.Tuple[str, FieldValueType, str],
|
||||
]
|
||||
|
||||
FieldSequenceType: typing.TypeAlias = typing.Sequence[typing.Union[typing.Tuple[str, FieldTupleType], RequestField]]
|
||||
FieldsType: typing.TypeAlias = typing.Union[
|
||||
FieldSequenceType,
|
||||
typing.Mapping[str, FieldTupleType],
|
||||
]
|
||||
|
||||
_HV = typing.TypeVar("_HV")
|
||||
|
||||
HookCallableType: typing.TypeAlias = typing.Callable[
|
||||
[_HV],
|
||||
typing.Optional[_HV],
|
||||
]
|
||||
|
||||
HookType: typing.TypeAlias = typing.Union[
|
||||
typing.Dict[str, typing.List[HookCallableType[_HV]]],
|
||||
"LifeCycleHook[_HV]",
|
||||
]
|
||||
|
||||
AsyncHookCallableType: typing.TypeAlias = typing.Callable[
|
||||
[_HV],
|
||||
typing.Awaitable[typing.Optional[_HV]],
|
||||
]
|
||||
|
||||
AsyncHookType: typing.TypeAlias = typing.Union[
|
||||
typing.Dict[str, typing.List[typing.Union[HookCallableType[_HV], AsyncHookCallableType[_HV]]]],
|
||||
"AsyncLifeCycleHook[_HV]",
|
||||
]
|
||||
|
||||
CacheLayerAltSvcType: typing.TypeAlias = typing.MutableMapping[typing.Tuple[str, int], typing.Optional[typing.Tuple[str, int]]]
|
||||
|
||||
RetryType: typing.TypeAlias = typing.Union[bool, int, Retry]
|
||||
|
||||
ResolverType: typing.TypeAlias = typing.Union[
|
||||
str,
|
||||
ResolverDescription,
|
||||
BaseResolver,
|
||||
typing.List[str],
|
||||
typing.List[ResolverDescription],
|
||||
]
|
||||
|
||||
AsyncResolverType: typing.TypeAlias = typing.Union[
|
||||
str,
|
||||
AsyncResolverDescription,
|
||||
AsyncBaseResolver,
|
||||
typing.List[str],
|
||||
typing.List[AsyncResolverDescription],
|
||||
]
|
||||
|
||||
ASGIScope: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
|
||||
ASGIMessage: typing.TypeAlias = typing.MutableMapping[str, typing.Any]
|
||||
ASGIReceive: typing.TypeAlias = typing.Callable[[], typing.Awaitable[ASGIMessage]]
|
||||
ASGISend: typing.TypeAlias = typing.Callable[[ASGIMessage], typing.Awaitable[None]]
|
||||
|
||||
ASGIApp: typing.TypeAlias = typing.Callable[[ASGIScope, ASGIReceive, ASGISend], typing.Awaitable[None]]
|
||||
|
||||
WSGIStartResponse: typing.TypeAlias = typing.Callable[
|
||||
[str, typing.List[typing.Tuple[str, str]], typing.Optional[typing.Any]],
|
||||
typing.Callable[[bytes], None],
|
||||
]
|
||||
WSGIApp: typing.TypeAlias = typing.Callable[
|
||||
[typing.Dict[str, typing.Any], WSGIStartResponse],
|
||||
typing.Iterable[bytes],
|
||||
]
|
||||
1475
.venv/lib/python3.9/site-packages/niquests/utils.py
Normal file
1475
.venv/lib/python3.9/site-packages/niquests/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user