from __future__ import annotations

import io
import typing
from base64 import b64encode
from enum import Enum

from ..exceptions import UnrewindableBodyError
from .util import to_bytes, iscoroutinefunction
from .response import BytesQueueBuffer

# Names below are only needed by type checkers; annotations are lazy
# thanks to ``from __future__ import annotations``.
if typing.TYPE_CHECKING:
    from typing_extensions import Final
    from .._typing import _TYPE_BODY_POSITION

# Pass as a value within ``headers`` to skip
# emitting some HTTP headers that are added automatically.
# The only headers that are supported are ``Accept-Encoding``,
# ``Host``, and ``User-Agent``.
SKIP_HEADER = "@@@SKIP_HEADER@@@"
SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"])

# NOTE(review): presumably the body-describing headers that must not be
# carried over when a request is re-issued without its body -- this set is
# not used in the visible part of the file; confirm against its callers.
NOT_FORWARDABLE_HEADERS = frozenset(
    [
        "content-encoding",
        "content-language",
        "content-location",
        "content-type",
        "content-length",
        "digest",
        "last-modified",
    ]
)

# Default ``Accept-Encoding`` value. It is extended below, at import time,
# for each optional decompression package that can be imported.
ACCEPT_ENCODING = "gzip,deflate"
# Append ",br" only when either ``brotlicffi`` or ``brotli`` is importable.
try:
    try:
        import brotlicffi as _unused_module_brotli  # type: ignore[import-not-found] # noqa: F401
    except ImportError:
        import brotli as _unused_module_brotli  # type: ignore[import-not-found] # noqa: F401
except ImportError:
    pass
else:
    ACCEPT_ENCODING += ",br"
# Append ",zstd" only when ``zstandard`` is importable.
try:
    import zstandard as _unused_module_zstd  # noqa: F401
except ImportError:
    pass
else:
    ACCEPT_ENCODING += ",zstd"


# Single-member enum whose only value serves as a typed sentinel.
class _TYPE_FAILEDTELL(Enum):
    token = 0


# Sentinel recorded as the body position when ``tell()`` raised ``OSError``;
# it is distinct from ``None`` so a later rewind attempt can detect it.
_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token

# When sending a request with these methods we aren't expecting
# a body so don't need to set an explicit 'Content-Length: 0'
# The reason we do this in the negative instead of tracking methods
# which 'should' have a body is because unknown methods should be
# treated as if they were 'POST' which *does* expect a body.
_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"}


def make_headers(
    keep_alive: bool | None = None,
    accept_encoding: bool | list[str] | str | None = None,
    user_agent: str | None = None,
    basic_auth: str | None = None,
    proxy_basic_auth: str | None = None,
    disable_cache: bool | None = None,
) -> dict[str, str]:
    """
    Shortcuts for generating request headers.

    :param keep_alive:
        If ``True``, adds 'connection: keep-alive' header.

    :param accept_encoding:
        Can be a boolean, list, or string.
        ``True`` translates to the module-level ``ACCEPT_ENCODING`` value:
        'gzip,deflate', followed by 'br' when either the ``brotli`` or
        ``brotlicffi`` package is installed and by 'zstd' when the
        ``zstandard`` package is installed.
        List will get joined by comma.
        String will be used as provided.

    :param user_agent:
        String representing the user-agent you want, such as
        "python-urllib3/0.6"

    :param basic_auth:
        Colon-separated username:password string for 'authorization: basic ...'
        auth header.

    :param proxy_basic_auth:
        Colon-separated username:password string for 'proxy-authorization: basic ...'
        auth header.

    :param disable_cache:
        If ``True``, adds 'cache-control: no-cache' header.

    :return:
        A new dict holding only the headers that were requested, keyed by
        lower-case header name.

    Example:

    .. code-block:: python

        import urllib3

        print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0"))
        # {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'}
        print(urllib3.util.make_headers(accept_encoding=True))
        # {'accept-encoding': 'gzip,deflate'}
    """
    headers: dict[str, str] = {}

    if accept_encoding:
        if isinstance(accept_encoding, str):
            pass
        elif isinstance(accept_encoding, list):
            accept_encoding = ",".join(accept_encoding)
        else:
            # Any other truthy value (typically ``True``) selects the default.
            accept_encoding = ACCEPT_ENCODING
        headers["accept-encoding"] = accept_encoding

    if user_agent:
        headers["user-agent"] = user_agent

    if keep_alive:
        headers["connection"] = "keep-alive"

    if basic_auth:
        headers["authorization"] = (
            f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}"
        )

    if proxy_basic_auth:
        headers["proxy-authorization"] = (
            f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}"
        )

    if disable_cache:
        headers["cache-control"] = "no-cache"

    return headers


def is_async_file(body: typing.Any) -> bool:
    """
    Tell whether ``body.tell`` is a coroutine function.

    Objects without a ``tell`` attribute are checked against a plain lambda,
    so they are reported as not asynchronous.
    """
    return iscoroutinefunction(getattr(body, "tell", lambda: None))


def set_file_position(
    body: typing.Any, pos: _TYPE_BODY_POSITION | None
) -> _TYPE_BODY_POSITION | None:
    """
    If a position is provided, move file to that point.
    Otherwise, we'll attempt to record a position for future use.

    :param body:
        Request body; only used when it exposes ``tell`` (to record) or
        ``seek`` (to rewind).

    :param pos:
        Previously recorded position, or ``None`` to record one now.

    :return:
        The position that was given or recorded, ``_FAILEDTELL`` if
        ``tell()`` raised ``OSError``, or ``None`` when ``body`` has no
        usable ``tell`` attribute.
    """
    if pos is not None:
        rewind_body(body, pos)
    elif getattr(body, "tell", None) is not None:
        try:
            pos = body.tell()
        except OSError:
            # This differentiates from None, allowing us to catch
            # a failed `tell()` later when trying to rewind the body.
            pos = _FAILEDTELL

    return pos


async def aset_file_position(
    body: typing.Any, pos: _TYPE_BODY_POSITION | None
) -> _TYPE_BODY_POSITION | None:
    """
    Asynchronous variant for set_file_position when using anyio.AsyncFile for example.

    Behaves like :func:`set_file_position` except that ``body.tell()`` and
    the rewind are awaited.
    """
    if pos is not None:
        await arewind_body(body, pos)
    elif getattr(body, "tell", None) is not None:
        try:
            pos = await body.tell()
        except OSError:
            # This differentiates from None, allowing us to catch
            # a failed `tell()` later when trying to rewind the body.
            pos = _FAILEDTELL

    return pos


def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None:
    """
    Attempt to rewind body to a certain position.
    Primarily used for request redirects and retries.

    :param body:
        File-like object that supports seek.

    :param body_pos:
        Position to seek to in file.

    :raises UnrewindableBodyError:
        If seeking raised ``OSError``, or if ``body_pos`` is the
        ``_FAILEDTELL`` sentinel (the position could not be recorded).

    :raises ValueError:
        If ``body_pos`` is neither an integer usable with ``body.seek``
        nor the ``_FAILEDTELL`` sentinel.
    """
    body_seek = getattr(body, "seek", None)
    if body_seek is not None and isinstance(body_pos, int):
        try:
            body_seek(body_pos)
        except OSError as e:
            raise UnrewindableBodyError(
                "An error occurred when rewinding request body for redirect/retry."
            ) from e
    elif body_pos is _FAILEDTELL:
        raise UnrewindableBodyError(
            "Unable to record file position for rewinding "
            "request body during a redirect/retry."
        )
    else:
        raise ValueError(
            f"body_pos must be of type integer, instead it was {type(body_pos)}."
        )


async def arewind_body(body: typing.Any, body_pos: _TYPE_BODY_POSITION) -> None:
    """
    Async counterpart of rewind_body.

    Same contract and exceptions as :func:`rewind_body`, except that
    ``body.seek(body_pos)`` is awaited.
    """
    body_seek = getattr(body, "seek", None)
    if body_seek is not None and isinstance(body_pos, int):
        try:
            await body_seek(body_pos)
        except OSError as e:
            raise UnrewindableBodyError(
                "An error occurred when rewinding request body for redirect/retry."
            ) from e
    elif body_pos is _FAILEDTELL:
        raise UnrewindableBodyError(
            "Unable to record file position for rewinding "
            "request body during a redirect/retry."
        )
    else:
        raise ValueError(
            f"body_pos must be of type integer, instead it was {type(body_pos)}."
        )


class ChunksAndContentLength(typing.NamedTuple):
    """Result of :func:`body_to_chunks`."""

    # Pieces of the body to send, or ``None`` when there is no body.
    chunks: typing.Iterable[bytes] | typing.AsyncIterable[bytes] | None
    # Known body size in bytes, or ``None`` when it can't be determined.
    content_length: int | None
    # ``True`` when the caller supplied text (``str``, ``io.StringIO`` or
    # ``io.TextIOWrapper``) rather than binary data.
    is_string: bool


def body_to_chunks(
    body: typing.Any | None,
    method: str,
    blocksize: int,
    force: bool = False,
) -> ChunksAndContentLength:
    """Takes the HTTP request method, body, and blocksize and
    transforms them into an iterable of chunks to pass to
    socket.sendall() and an optional 'Content-Length' header.

    A 'Content-Length' of 'None' indicates the length of the body
    can't be determined so should use 'Transfer-Encoding: chunked'
    for framing instead.

    :param body:
        ``None``, ``str``/``bytes``, a file-like object with ``read``, an
        async iterable (optionally with an awaitable ``read``), an object
        implementing the buffer API, or an iterable.

    :param method:
        HTTP method; only consulted when ``body`` is ``None``.

    :param blocksize:
        Maximum size passed to each ``read`` / buffer ``get`` call.

    :param force:
        If ``True``, a ``str``/``bytes`` body of at least ``blocksize``
        encoded bytes is split into ``blocksize``-sized chunks instead of
        being returned as a single chunk.

    :raises TypeError:
        If ``body`` matches none of the supported kinds.
    """

    chunks: typing.Iterable[bytes] | typing.AsyncIterable[bytes] | None
    content_length: int | None

    # No body, we need to make a recommendation on 'Content-Length'
    # based on whether that request method is expected to have
    # a body or not.
    if body is None:
        chunks = None
        if method.upper() not in _METHODS_NOT_EXPECTING_BODY:
            content_length = 0
        else:
            content_length = None

        return ChunksAndContentLength(
            chunks=chunks,
            content_length=content_length,
            is_string=False,
        )

    # Decide this from the object the caller actually passed, before any
    # re-wrapping below, so the answer doesn't depend on blocksize/force.
    is_string = isinstance(
        body,
        (
            str,
            io.StringIO,
            io.TextIOWrapper,
        ),
    )

    def chunk_readable(stream: typing.Any) -> typing.Iterable[bytes]:
        # Text streams yield str blocks, which are encoded as UTF-8.
        encode = isinstance(stream, io.TextIOBase)

        while True:
            datablock = stream.read(blocksize)
            if not datablock:
                break
            if encode:
                datablock = datablock.encode("utf-8")
            yield datablock

    async def chunk_areadable() -> typing.AsyncIterable[bytes]:
        assert body is not None and hasattr(body, "__aiter__")
        # Whether blocks are str is learnt from the first block received.
        encode: bool | None = None

        if hasattr(body, "read"):
            while True:
                block = await body.read(blocksize)
                if not block:
                    break
                if encode is None:
                    encode = isinstance(block, str)
                yield block.encode("utf-8") if encode else block
        else:
            # branch for:
            # async iterable protocol don't support read by blocksize
            # so we must buffer it. the best we have
            # at hand is the efficient BytesQueueBuffer.
            queue = BytesQueueBuffer()

            async for block in body:
                if encode is None:
                    encode = isinstance(block, str)
                queue.put(block.encode("utf-8") if encode else block)
                yield queue.get(blocksize)

            # Drain whatever is still buffered once the source is exhausted.
            while queue:
                yield queue.get(blocksize)

    # Bytes or strings become bytes
    if isinstance(body, (str, bytes)):
        converted = to_bytes(body)
        content_length = len(converted)

        if force and content_length >= blocksize:
            chunks = chunk_readable(io.BytesIO(converted))
        else:
            chunks = (converted,)

    # File-like object, TODO: use seek() and tell() for length?
    elif hasattr(body, "read") and not hasattr(body, "__aiter__"):
        chunks = chunk_readable(body)
        content_length = None

    elif hasattr(body, "__aiter__"):
        chunks = chunk_areadable()
        content_length = None

    # Otherwise we need to start checking via duck-typing.
    else:
        try:
            # Check if the body implements the buffer API.
            mv = memoryview(body)
        except TypeError:
            try:
                # Check if the body is an iterable
                chunks = iter(body)
                content_length = None
            except TypeError:
                raise TypeError(
                    f"'body' must be a bytes-like object, file-like "
                    f"object, or iterable. Instead was {body!r}"
                ) from None
        else:
            # Since it implements the buffer API can be passed directly to socket.sendall()
            chunks = (body,)
            content_length = mv.nbytes

    return ChunksAndContentLength(
        chunks=chunks,
        content_length=content_length,
        is_string=is_string,
    )