from __future__ import annotations

import logging
import typing
import warnings
from types import TracebackType
from urllib.parse import urljoin

from .._collections import HTTPHeaderDict
from .._constant import DEFAULT_BLOCKSIZE
from .._request_methods import AsyncRequestMethods
from .._typing import _TYPE_BODY, _TYPE_BODY_POSITION, _TYPE_TIMEOUT, ProxyConfig
from ..backend import HttpVersion, QuicPreemptiveCacheType, ResponsePromise
from ..connectionpool import port_by_scheme
from ..contrib.resolver import ProtocolResolver
from ..contrib.resolver._async import (
    AsyncBaseResolver,
    AsyncManyResolver,
    AsyncResolverDescription,
)
from ..exceptions import (
    LocationValueError,
    MaxRetryError,
    ProxySchemeUnknown,
    URLSchemeUnknown,
)
from ..poolmanager import SSL_KEYWORDS, PoolKey, key_fn_by_scheme
from ..util._async.traffic_police import AsyncTrafficPolice
from ..util.proxy import connection_requires_http_tunnel
from ..util.request import NOT_FORWARDABLE_HEADERS
from ..util.retry import Retry
from ..util.traffic_police import UnavailableTraffic
from ..util.url import Url, parse_extension, parse_url
from .connectionpool import AsyncHTTPConnectionPool, AsyncHTTPSConnectionPool
from .response import AsyncHTTPResponse

if typing.TYPE_CHECKING:
    import ssl

    from typing_extensions import Literal

__all__ = ["AsyncPoolManager", "AsyncProxyManager", "proxy_from_url"]

log = logging.getLogger(__name__)

_SelfT = typing.TypeVar("_SelfT")

pool_classes_by_scheme = {
    "http": AsyncHTTPConnectionPool,
    "https": AsyncHTTPSConnectionPool,
}


class AsyncPoolManager(AsyncRequestMethods):
    """
    Allows for arbitrary async requests while transparently keeping track of
    necessary connection pools for you.

    :param num_pools:
        Number of connection pools to cache before discarding the
        least recently used pool.

    :param headers:
        Headers to include with all requests, unless other headers are given
        explicitly.

    :param \\**connection_pool_kw:
        Additional parameters are used to create fresh
        :class:`urllib3._async.connectionpool.AsyncConnectionPool` instances.

    Example:

    .. code-block:: python

        import urllib3

        http = urllib3.AsyncPoolManager(num_pools=2)

        resp1 = await http.request("GET", "https://google.com/")
        resp2 = await http.request("GET", "https://google.com/mail")
        resp3 = await http.request("GET", "https://yahoo.com/")

        print(len(http.pools))
        # 2
    """

    proxy: Url | None = None
    proxy_config: ProxyConfig | None = None

    def __init__(
        self,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        preemptive_quic_cache: QuicPreemptiveCacheType | None = None,
        resolver: AsyncResolverDescription
        | list[AsyncResolverDescription]
        | str
        | list[str]
        | AsyncBaseResolver
        | None = None,
        **connection_pool_kw: typing.Any,
    ) -> None:
        super().__init__(headers)

        # PoolManager handles redirects itself in PoolManager.urlopen().
        # It always passes redirect=False to the underlying connection pool to
        # suppress per-pool redirect handling. If the user supplied a non-Retry
        # value (int/bool/etc.) for retries and we let the pool normalize it
        # while redirect=False, the resulting Retry object would have redirect
        # handling disabled, which can interfere with PoolManager's own
        # redirect logic. Normalize here so redirects remain governed solely by
        # PoolManager logic.
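        #
        # Illustrative example (hypothetical values): a caller passing
        # ``retries=3`` is normalized via ``Retry.from_int(3)`` below, before
        # any pool ever sees it:
        #
        #     pm = AsyncPoolManager(retries=3)
        #     isinstance(pm.connection_pool_kw["retries"], Retry)  # True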
        if "retries" in connection_pool_kw:
            retries = connection_pool_kw["retries"]
            if not isinstance(retries, Retry):
                retries = Retry.from_int(retries)
                connection_pool_kw = connection_pool_kw.copy()
                connection_pool_kw["retries"] = retries

        self.connection_pool_kw = connection_pool_kw

        self._num_pools = num_pools
        self.block = (
            False if "block" not in connection_pool_kw else connection_pool_kw["block"]
        )

        self.pools: AsyncTrafficPolice[AsyncHTTPConnectionPool] = AsyncTrafficPolice(
            num_pools,
            concurrency=True,
            strict_maxsize=not self.block,
        )

        # Locally set the pool classes and keys so other PoolManagers can
        # override them.
        self.pool_classes_by_scheme = pool_classes_by_scheme
        self.key_fn_by_scheme = key_fn_by_scheme.copy()

        self._preemptive_quic_cache = preemptive_quic_cache

        self._own_resolver = not isinstance(resolver, AsyncBaseResolver)

        if resolver is None:
            resolver = [AsyncResolverDescription(ProtocolResolver.SYSTEM)]
        elif isinstance(resolver, str):
            resolver = [AsyncResolverDescription.from_url(resolver)]
        elif isinstance(resolver, AsyncResolverDescription):
            resolver = [resolver]

        self._resolvers: list[AsyncResolverDescription] = []

        if not isinstance(resolver, AsyncBaseResolver):
            can_resolve_localhost: bool = False

            for resolver_description in resolver:
                if isinstance(resolver_description, str):
                    self._resolvers.append(
                        AsyncResolverDescription.from_url(resolver_description)
                    )

                    if self._resolvers[-1].protocol == ProtocolResolver.SYSTEM:
                        can_resolve_localhost = True

                    continue

                self._resolvers.append(resolver_description)

                if self._resolvers[-1].protocol == ProtocolResolver.SYSTEM:
                    can_resolve_localhost = True

            if not can_resolve_localhost:
                self._resolvers.append(
                    AsyncResolverDescription.from_url(
                        "system://default?hosts=localhost"
                    )
                )

        #: We want to automatically forward ca_cert_data, ca_cert_dir, and ca_certs.
        for rd in self._resolvers:
            if "ca_cert_data" in connection_pool_kw:
                if "ca_cert_data" not in rd:
                    rd["ca_cert_data"] = connection_pool_kw["ca_cert_data"]
            if "ca_cert_dir" in connection_pool_kw:
                if "ca_cert_dir" not in rd:
                    rd["ca_cert_dir"] = connection_pool_kw["ca_cert_dir"]
            if "ca_certs" in connection_pool_kw:
                if "ca_certs" not in rd:
                    rd["ca_certs"] = connection_pool_kw["ca_certs"]

        self._resolver: AsyncBaseResolver = (
            AsyncManyResolver(*[r.new() for r in self._resolvers])
            if not isinstance(resolver, AsyncBaseResolver)
            else resolver
        )

    async def __aenter__(self: _SelfT) -> _SelfT:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> Literal[False]:
        await self.clear()
        # Return False to re-raise any potential exceptions
        return False

    def _new_pool(
        self,
        scheme: str,
        host: str,
        port: int,
        request_context: dict[str, typing.Any] | None = None,
    ) -> AsyncHTTPConnectionPool:
        """
        Create a new :class:`urllib3._async.connectionpool.AsyncConnectionPool`
        based on host, port, scheme, and any additional pool keyword arguments.

        If ``request_context`` is provided, it is provided as keyword arguments
        to the pool class used. This method is used to actually create the
        connection pools handed out by :meth:`connection_from_url` and
        companion methods. It is intended to be overridden for customization.
        """
        pool_cls: type[AsyncHTTPConnectionPool] = self.pool_classes_by_scheme[scheme]
        if request_context is None:
            request_context = self.connection_pool_kw.copy()

        # Default blocksize to DEFAULT_BLOCKSIZE if missing or explicitly
        # set to 'None' in the request_context.
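        # For instance, AsyncPoolManager(blocksize=None) still produces pools
        # created with blocksize=DEFAULT_BLOCKSIZE rather than None.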
        if request_context.get("blocksize") is None:
            request_context["blocksize"] = DEFAULT_BLOCKSIZE

        # Although the context has everything necessary to create the pool,
        # this function has historically only used the scheme, host, and port
        # in the positional args. When an API change is acceptable these can
        # be removed.
        for key in ("scheme", "host", "port"):
            request_context.pop(key, None)

        if scheme == "http":
            for kw in SSL_KEYWORDS:
                request_context.pop(kw, None)

        request_context["preemptive_quic_cache"] = self._preemptive_quic_cache

        if not self._resolver.is_available():
            self._resolver = self._resolver.recycle()

        request_context["resolver"] = self._resolver

        # By default, each HttpPool can have up to num_pools connections.
        if "maxsize" not in request_context:
            request_context["maxsize"] = self._num_pools

        return pool_cls(host, port, **request_context)

    async def clear(self) -> None:
        """
        Empty our store of pools and direct them all to close.

        This will not affect in-flight connections, but they will not be
        re-used after completion.
        """
        await self.pools.clear()

        if self._own_resolver and self._resolver.is_available():
            await self._resolver.close()

    async def connection_from_host(
        self,
        host: str | None,
        port: int | None = None,
        scheme: str | None = "http",
        pool_kwargs: dict[str, typing.Any] | None = None,
    ) -> AsyncHTTPConnectionPool:
        """
        Get a :class:`urllib3._async.connectionpool.AsyncConnectionPool`
        based on the host, port, and scheme.

        If ``port`` isn't given, it will be derived from the ``scheme`` using
        ``urllib3.connectionpool.port_by_scheme``. If ``pool_kwargs`` is
        provided, it is merged with the instance's ``connection_pool_kw``
        variable and used to create the new connection pool, if one is needed.
        """
        if not host:
            raise LocationValueError("No host specified.")

        request_context = self._merge_pool_kwargs(pool_kwargs)
        request_context["scheme"] = scheme or "http"
        if not port:
            port = port_by_scheme.get(request_context["scheme"].lower())
        request_context["port"] = port
        request_context["host"] = host

        return await self.connection_from_context(request_context)

    async def connection_from_context(
        self, request_context: dict[str, typing.Any]
    ) -> AsyncHTTPConnectionPool:
        """
        Get a :class:`urllib3._async.connectionpool.AsyncConnectionPool`
        based on the request context.

        ``request_context`` must at least contain the ``scheme`` key and its
        value must be a key in the ``key_fn_by_scheme`` instance variable.
        """
        if "strict" in request_context:
            request_context.pop("strict")

        scheme = request_context["scheme"].lower()
        pool_key_constructor = self.key_fn_by_scheme.get(scheme)

        if pool_key_constructor is None:
            target_scheme, target_implementation = parse_extension(scheme)

            from ..contrib.webextensions._async import load_extension

            try:
                extension = load_extension(target_scheme, target_implementation)
            except ImportError:
                pass
            else:
                scheme = extension.scheme_to_http_scheme(target_scheme)
                pool_key_constructor = self.key_fn_by_scheme.get(scheme)
                request_context["scheme"] = scheme

                supported_svn = extension.supported_svn()
                disabled_svn = (
                    request_context["disabled_svn"]
                    if "disabled_svn" in request_context
                    else set()
                )

                if len(extension.supported_svn()) != 3:
                    if HttpVersion.h11 not in supported_svn:
                        disabled_svn.add(HttpVersion.h11)
                    if HttpVersion.h2 not in supported_svn:
                        disabled_svn.add(HttpVersion.h2)
                    if HttpVersion.h3 not in supported_svn:
                        disabled_svn.add(HttpVersion.h3)

                request_context["disabled_svn"] = disabled_svn

        if not pool_key_constructor:
            raise URLSchemeUnknown(scheme)

        pool_key = pool_key_constructor(request_context)

        if self._preemptive_quic_cache is not None:
            request_context["preemptive_quic_cache"] = self._preemptive_quic_cache

        return await self.connection_from_pool_key(
            pool_key, request_context=request_context
        )

    async def connection_from_pool_key(
        self, pool_key: PoolKey, request_context: dict[str, typing.Any]
    ) -> AsyncHTTPConnectionPool:
        """
        Get a :class:`urllib3._async.connectionpool.AsyncConnectionPool`
        based on the provided pool key.

        ``pool_key`` should be a namedtuple that only contains immutable
        objects. At a minimum it must have the ``scheme``, ``host``, and
        ``port`` fields.
        """
        # If the scheme, host, or port doesn't match existing open
        # connections, open a new ConnectionPool.
        if self.pools.busy:
            self.pools.release()

        async with self.pools.locate_or_hold(
            pool_key,
            block=True,
        ) as swapper_or_pool:
            if not hasattr(swapper_or_pool, "is_idle"):
                # Make a fresh ConnectionPool of the desired type.
                scheme = request_context["scheme"]
                host = request_context["host"]
                port = request_context["port"]

                pool = self._new_pool(
                    scheme, host, port, request_context=request_context
                )
                await swapper_or_pool(pool)
            else:
                pool = swapper_or_pool  # type: ignore[assignment]

            assert pool.pool is not None
            pool.pool.parent = self.pools

            return pool

    async def connection_from_url(
        self, url: str, pool_kwargs: dict[str, typing.Any] | None = None
    ) -> AsyncHTTPConnectionPool:
        """
        Similar to :func:`urllib3.async_connection_from_url`.

        If ``pool_kwargs`` is not provided and a new pool needs to be
        constructed, ``self.connection_pool_kw`` is used to initialize
        the :class:`urllib3.connectionpool.ConnectionPool`. If ``pool_kwargs``
        is provided, it is used instead. Note that if a new pool does not
        need to be created for the request, the provided ``pool_kwargs`` are
        not used.
        """
        u = parse_url(url)
        return await self.connection_from_host(
            u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
        )

    def _merge_pool_kwargs(
        self, override: dict[str, typing.Any] | None
    ) -> dict[str, typing.Any]:
        """
        Merge a dictionary of override values for self.connection_pool_kw.

        This does not modify self.connection_pool_kw and returns a new dict.
        Any keys in the override dictionary with a value of ``None`` are
        removed from the merged dictionary.
        """
        base_pool_kwargs = self.connection_pool_kw.copy()
        if override:
            base_pool_kwargs.update(
                {k: v for k, v in override.items() if v is not None}
            )
            return {
                k: v
                for k, v in base_pool_kwargs.items()
                if k not in override or override[k] is not None
            }
        return base_pool_kwargs

    def _proxy_requires_url_absolute_form(self, parsed_url: Url) -> bool:
        """
        Indicates if the proxy requires the complete destination URL in the
        request. Normally this is only needed when not using an HTTP CONNECT
        tunnel.
        """
        if self.proxy is None:
            return False

        return not connection_requires_http_tunnel(
            self.proxy, self.proxy_config, parsed_url.scheme
        )

    async def get_response(
        self, *, promise: ResponsePromise | None = None
    ) -> AsyncHTTPResponse | None:
        """
        Retrieve the first response available in the pools.

        This method should be called after issuing at least one request with
        ``multiplexed=True``. If no response is available, return ``None``.
        """
        if promise is not None and not isinstance(promise, ResponsePromise):
            raise TypeError(
                f"get_response only supports ResponsePromise but received {type(promise)} instead. "
                "This may occur if you expected the remote peer to support multiplexing but it did not."
            )

        try:
            async with self.pools.borrow(
                promise or ResponsePromise,
                block=False,
                not_idle_only=True,
            ) as pool:
                response = await pool.get_response(promise=promise)
        except UnavailableTraffic:
            return None

        if promise is not None and response is None:
            raise ValueError(
                "Invoked get_response with promise=... that no connections across pools recognize"
            )

        if response is None:
            return None

        from_promise = None

        if promise:
            from_promise = promise
        else:
            if (
                response._fp
                and hasattr(response._fp, "from_promise")
                and response._fp.from_promise
            ):
                from_promise = response._fp.from_promise

        if from_promise is None:
            raise ValueError(
                "Internal: Unable to identify originating ResponsePromise from a LowLevelResponse"
            )

        self.pools.forget(from_promise)

        # Retrieve the request context stored on the promise.
        method = typing.cast(str, from_promise.get_parameter("method"))
        redirect = typing.cast(bool, from_promise.get_parameter("pm_redirect"))

        # Handle redirect?
        if redirect and response.get_redirect_location():
            url = typing.cast(str, from_promise.get_parameter("pm_url"))
            body = typing.cast(
                typing.Union[_TYPE_BODY, None], from_promise.get_parameter("body")
            )
            headers = typing.cast(
                typing.Union[HTTPHeaderDict, None],
                from_promise.get_parameter("headers"),
            )
            preload_content = typing.cast(
                bool, from_promise.get_parameter("preload_content")
            )
            decode_content = typing.cast(
                bool, from_promise.get_parameter("decode_content")
            )
            timeout = typing.cast(
                typing.Union[_TYPE_TIMEOUT, None], from_promise.get_parameter("timeout")
            )
            assert_same_host = typing.cast(
                bool, from_promise.get_parameter("assert_same_host")
            )
            pool_timeout = from_promise.get_parameter("pool_timeout")
            response_kw = typing.cast(
                typing.MutableMapping[str, typing.Any],
                from_promise.get_parameter("response_kw"),
            )
            chunked = typing.cast(bool, from_promise.get_parameter("chunked"))
            body_pos = typing.cast(
                _TYPE_BODY_POSITION, from_promise.get_parameter("body_pos")
            )
            retries = typing.cast(Retry, from_promise.get_parameter("retries"))

            redirect_location = response.get_redirect_location()
            assert isinstance(redirect_location, str)

            if response.status == 303:
                method = "GET"
                body = None
                headers = HTTPHeaderDict(headers)

                for should_be_removed_header in NOT_FORWARDABLE_HEADERS:
                    headers.discard(should_be_removed_header)

            try:
                retries = retries.increment(
                    method, url, response=response, _pool=response._pool
                )
            except MaxRetryError:
                if retries.raise_on_redirect:
                    await response.drain_conn()
                    raise
                return response

            await response.drain_conn()
            await retries.async_sleep_for_retry(response)

            log.debug("Redirecting %s -> %s", url, redirect_location)

            new_promise = await self.urlopen(
                method,
                urljoin(url, redirect_location),
                True,
                body=body,
                headers=headers,
                retries=retries,
                assert_same_host=assert_same_host,
                timeout=timeout,
                pool_timeout=pool_timeout,
                release_conn=True,
                chunked=chunked,
                body_pos=body_pos,
                preload_content=preload_content,
                decode_content=decode_content,
                multiplexed=True,
                **response_kw,
            )

            return await self.get_response(promise=new_promise if promise else None)

        # Check if we should retry the HTTP response.
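        # Mirrors the redirect handling above: when the Retry policy deems the
        # status retryable (e.g. because of a Retry-After header), drain the
        # response, sleep as instructed, then re-issue the request.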
        has_retry_after = bool(response.headers.get("Retry-After"))
        retries = typing.cast(Retry, from_promise.get_parameter("retries"))

        if retries.is_retry(method, response.status, has_retry_after):
            url = typing.cast(str, from_promise.get_parameter("pm_url"))
            body = typing.cast(
                typing.Union[_TYPE_BODY, None], from_promise.get_parameter("body")
            )
            headers = typing.cast(
                typing.Union[HTTPHeaderDict, None],
                from_promise.get_parameter("headers"),
            )
            preload_content = typing.cast(
                bool, from_promise.get_parameter("preload_content")
            )
            decode_content = typing.cast(
                bool, from_promise.get_parameter("decode_content")
            )
            timeout = typing.cast(
                typing.Union[_TYPE_TIMEOUT, None], from_promise.get_parameter("timeout")
            )
            assert_same_host = typing.cast(
                bool, from_promise.get_parameter("assert_same_host")
            )
            pool_timeout = from_promise.get_parameter("pool_timeout")
            response_kw = typing.cast(
                typing.MutableMapping[str, typing.Any],
                from_promise.get_parameter("response_kw"),
            )
            chunked = typing.cast(bool, from_promise.get_parameter("chunked"))
            body_pos = typing.cast(
                _TYPE_BODY_POSITION, from_promise.get_parameter("body_pos")
            )

            try:
                retries = retries.increment(
                    method, url, response=response, _pool=response._pool
                )
            except MaxRetryError:
                if retries.raise_on_status:
                    await response.drain_conn()
                    raise
                return response

            await response.drain_conn()
            await retries.async_sleep(response)

            log.debug("Retry: %s", url)

            new_promise = await self.urlopen(
                method,
                url,
                True,
                body=body,
                headers=headers,
                retries=retries,
                assert_same_host=assert_same_host,
                timeout=timeout,
                pool_timeout=pool_timeout,
                release_conn=False,
                chunked=chunked,
                body_pos=body_pos,
                preload_content=preload_content,
                decode_content=decode_content,
                multiplexed=True,
                **response_kw,
            )

            return await self.get_response(promise=new_promise if promise else None)

        extension = from_promise.get_parameter("extension")

        if extension is not None and response.extension is None:
            if response.status == 101 or (
                200 <= response.status < 300
                and (method == "CONNECT" or extension is not None)
            ):
                if extension is None:
                    from ..contrib.webextensions._async import load_extension

                    extension = load_extension(None)()
                await response.start_extension(extension)

        return response

    @typing.overload  # type: ignore[override]
    async def urlopen(
        self,
        method: str,
        url: str,
        redirect: bool = True,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        redirect: bool = True,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...

    async def urlopen(
        self, method: str, url: str, redirect: bool = True, **kw: typing.Any
    ) -> AsyncHTTPResponse | ResponsePromise:
        """
        Same as :meth:`urllib3.AsyncHTTPConnectionPool.urlopen` with custom
        cross-host redirect logic; only the request-uri portion of the ``url``
        is sent.

        The given ``url`` parameter must be absolute, such that an appropriate
        :class:`urllib3._async.connectionpool.AsyncConnectionPool` can be
        chosen for it.
        """
        u = parse_url(url)

        if u.scheme is None:
            warnings.warn(
                "URLs without a scheme (ie 'https://') are deprecated and will raise an error "
                "in a future version of urllib3. To avoid this DeprecationWarning ensure all URLs "
                "start with 'https://' or 'http://'. Read more in this issue: "
                "https://github.com/urllib3/urllib3/issues/2920",
                category=DeprecationWarning,
                stacklevel=2,
            )

        # If an extension was passed manually to urlopen, disable any HTTP
        # protocol versions (svn) that the extension does not support.
        pool_kwargs = None

        if "extension" in kw and kw["extension"] is not None:
            extension = kw["extension"]
            disabled_svn = set()
            pool_kwargs = {}

            if len(extension.supported_svn()) != 3:
                if HttpVersion.h11 not in extension.supported_svn():
                    disabled_svn.add(HttpVersion.h11)
                if HttpVersion.h2 not in extension.supported_svn():
                    disabled_svn.add(HttpVersion.h2)
                if HttpVersion.h3 not in extension.supported_svn():
                    disabled_svn.add(HttpVersion.h3)

            pool_kwargs["disabled_svn"] = disabled_svn

        conn = await self.connection_from_host(
            u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
        )

        if u.scheme is not None and u.scheme.lower() not in ("http", "https"):
            from ..contrib.webextensions._async import load_extension

            extension = load_extension(*parse_extension(u.scheme))
            kw["extension"] = extension()
            kw.update(kw["extension"].urlopen_kwargs)

        kw["assert_same_host"] = False
        kw["redirect"] = False

        if "headers" not in kw:
            kw["headers"] = self.headers

        if self._proxy_requires_url_absolute_form(u):
            response = await conn.urlopen(method, url, **kw)
        else:
            response = await conn.urlopen(method, u.request_uri, **kw)

        self.pools.release()

        if "multiplexed" in kw and kw["multiplexed"]:
            if isinstance(response, ResponsePromise):
                response.set_parameter("pm_redirect", redirect)
                response.set_parameter("pm_url", url)

                assert isinstance(response, ResponsePromise)
                return response
            # The established connection is not capable of multiplexed requests.
            kw["multiplexed"] = False

        assert isinstance(response, AsyncHTTPResponse)

        redirect_location = redirect and response.get_redirect_location()
        if not redirect_location:
            return response

        # Support relative URLs for redirecting.
        redirect_location = urljoin(url, redirect_location)

        # RFC 7231, Section 6.4.4
        if response.status == 303:
            method = "GET"
            kw["body"] = None
            kw["headers"] = HTTPHeaderDict(kw["headers"])

            for should_be_removed_header in NOT_FORWARDABLE_HEADERS:
                kw["headers"].discard(should_be_removed_header)

        retries = kw.get("retries", response.retries)
        if not isinstance(retries, Retry):
            retries = Retry.from_int(retries, redirect=redirect)

        # Strip headers marked as unsafe to forward to the redirected location.
        # Check remove_headers_on_redirect to avoid a potential network call within
        # conn.is_same_host() which may use socket.gethostbyname() in the future.
        if retries.remove_headers_on_redirect and not conn.is_same_host(
            redirect_location
        ):
            new_headers = kw["headers"].copy()
            for header in kw["headers"]:
                if header.lower() in retries.remove_headers_on_redirect:
                    new_headers.pop(header, None)
            kw["headers"] = new_headers

        try:
            retries = retries.increment(method, url, response=response, _pool=conn)
        except MaxRetryError:
            if retries.raise_on_redirect:
                await response.drain_conn()
                raise
            return response

        kw["retries"] = retries
        kw["redirect"] = redirect

        log.info("Redirecting %s -> %s", url, redirect_location)

        await response.drain_conn()

        return await self.urlopen(method, redirect_location, **kw)  # type: ignore[no-any-return]

    def __repr__(self) -> str:
        inner_repr = "; ".join(repr(p) for p in self.pools._registry.values())
        if inner_repr:
            inner_repr += " "
        return f"<{type(self).__name__} {inner_repr}{hex(id(self))}>"


class AsyncProxyManager(AsyncPoolManager):
    """
    Behaves just like :class:`AsyncPoolManager`, but sends all requests through
    the defined proxy, using the CONNECT method for HTTPS URLs.

    :param proxy_url:
        The URL of the proxy to be used.

    :param proxy_headers:
        A dictionary containing headers that will be sent to the proxy. With
        HTTP proxying they are sent with each request, while in the
        HTTPS/CONNECT case they are sent only once. Can be used for proxy
        authentication.

    :param proxy_ssl_context:
        The proxy SSL context is used to establish the TLS connection to the
        proxy when using HTTPS proxies.

    :param use_forwarding_for_https:
        (Defaults to False) If set to True will forward requests to the HTTPS
        proxy to be made on behalf of the client instead of creating a TLS
        tunnel via the CONNECT method. **Enabling this flag means that request
        and response headers and content will be visible from the HTTPS
        proxy** whereas tunneling keeps request and response headers and
        content private. IP address, target hostname, SNI, and port are
        always visible to an HTTPS proxy even when this flag is disabled.

    :param proxy_assert_hostname:
        The hostname of the certificate to verify against.

    :param proxy_assert_fingerprint:
        The fingerprint of the certificate to verify against.

    Example:

    .. code-block:: python

        import urllib3

        proxy = urllib3.AsyncProxyManager("https://localhost:3128/")

        resp1 = await proxy.request("GET", "http://google.com/")
        resp2 = await proxy.request("GET", "http://httpbin.org/")
        print(len(proxy.pools))
        # 1

        resp3 = await proxy.request("GET", "https://httpbin.org/")
        resp4 = await proxy.request("GET", "https://twitter.com/")
        print(len(proxy.pools))
        # 3
    """

    def __init__(
        self,
        proxy_url: str,
        num_pools: int = 10,
        headers: typing.Mapping[str, str] | None = None,
        proxy_headers: typing.Mapping[str, str] | None = None,
        proxy_ssl_context: ssl.SSLContext | None = None,
        use_forwarding_for_https: bool = False,
        proxy_assert_hostname: None | str | Literal[False] = None,
        proxy_assert_fingerprint: str | None = None,
        **connection_pool_kw: typing.Any,
    ) -> None:
        if isinstance(proxy_url, AsyncHTTPConnectionPool):
            str_proxy_url = f"{proxy_url.scheme}://{proxy_url.host}:{proxy_url.port}"
        else:
            str_proxy_url = proxy_url
        proxy = parse_url(str_proxy_url)

        if proxy.scheme not in ("http", "https"):
            raise ProxySchemeUnknown(proxy.scheme)

        if not proxy.port:
            port = port_by_scheme.get(proxy.scheme, 80)
            proxy = proxy._replace(port=port)

        self.proxy = proxy
        self.proxy_headers = proxy_headers or {}
        self.proxy_ssl_context = proxy_ssl_context
        self.proxy_config = ProxyConfig(
            proxy_ssl_context,
            use_forwarding_for_https,
            proxy_assert_hostname,
            proxy_assert_fingerprint,
        )

        connection_pool_kw["_proxy"] = self.proxy
        connection_pool_kw["_proxy_headers"] = self.proxy_headers
        connection_pool_kw["_proxy_config"] = self.proxy_config

        super().__init__(num_pools, headers, **connection_pool_kw)

    async def connection_from_host(
        self,
        host: str | None,
        port: int | None = None,
        scheme: str | None = "http",
        pool_kwargs: dict[str, typing.Any] | None = None,
    ) -> AsyncHTTPConnectionPool:
        if scheme == "https":
            return await super().connection_from_host(
                host, port, scheme, pool_kwargs=pool_kwargs
            )

        assert self.proxy is not None

        return await super().connection_from_host(
            self.proxy.host,
            self.proxy.port,
            self.proxy.scheme,
            pool_kwargs=pool_kwargs,
        )

    def _set_proxy_headers(
        self, url: str, headers: typing.Mapping[str, str] | None = None
    ) -> typing.Mapping[str, str]:
        """
        Sets headers needed by proxies: specifically, the Accept and Host
        headers. Only sets headers not provided by the user.
        """
        headers_ = {"Accept": "*/*"}

        netloc = parse_url(url).netloc
        if netloc:
            headers_["Host"] = netloc

        if headers:
            headers_.update(headers)
        return headers_

    @typing.overload  # type: ignore[override]
    async def urlopen(
        self,
        method: str,
        url: str,
        redirect: bool = True,
        *,
        multiplexed: Literal[False] = ...,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse: ...

    @typing.overload
    async def urlopen(
        self,
        method: str,
        url: str,
        redirect: bool = True,
        *,
        multiplexed: Literal[True],
        **kw: typing.Any,
    ) -> ResponsePromise: ...

    async def urlopen(
        self,
        method: str,
        url: str,
        redirect: bool = True,
        **kw: typing.Any,
    ) -> AsyncHTTPResponse | ResponsePromise:
        "Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
        u = parse_url(url)

        if not connection_requires_http_tunnel(
            self.proxy, self.proxy_config, u.scheme
        ):
            # For connections using HTTP CONNECT, the necessary headers are set
            # on the CONNECT request to the proxy. If we're not using CONNECT,
            # we'll definitely need to set 'Host' at the very least.
            headers = kw.get("headers", self.headers)
            kw["headers"] = self._set_proxy_headers(url, headers)

        return await super().urlopen(method, url, redirect=redirect, **kw)  # type: ignore[no-any-return]

    def __repr__(self) -> str:
        inner_repr = "; ".join(repr(p) for p in self.pools._registry.values())
        if inner_repr:
            inner_repr += " "
        return f"<{type(self).__name__} {inner_repr}{hex(id(self))}>"


def proxy_from_url(url: str, **kw: typing.Any) -> AsyncProxyManager:
    return AsyncProxyManager(proxy_url=url, **kw)