fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
42
.venv/lib/python3.9/site-packages/qh3/__init__.py
Normal file
42
.venv/lib/python3.9/site-packages/qh3/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from .asyncio import QuicConnectionProtocol, connect, serve
|
||||
from .h3 import events as h3_events
|
||||
from .h3.connection import H3Connection, ProtocolError
|
||||
from .h3.exceptions import H3Error, NoAvailablePushIDError
|
||||
from .quic import events as quic_events
|
||||
from .quic.configuration import QuicConfiguration
|
||||
from .quic.connection import QuicConnection, QuicConnectionError
|
||||
from .quic.logger import QuicFileLogger, QuicLogger
|
||||
from .quic.packet import QuicProtocolVersion
|
||||
from .tls import CipherSuite, SessionTicket
|
||||
|
||||
__version__ = "1.6.0"
|
||||
|
||||
__all__ = (
|
||||
"connect",
|
||||
"QuicConnectionProtocol",
|
||||
"serve",
|
||||
"h3_events",
|
||||
"H3Error",
|
||||
"H3Connection",
|
||||
"NoAvailablePushIDError",
|
||||
"quic_events",
|
||||
"QuicConfiguration",
|
||||
"QuicConnection",
|
||||
"QuicConnectionError",
|
||||
"QuicProtocolVersion",
|
||||
"QuicFileLogger",
|
||||
"QuicLogger",
|
||||
"ProtocolError",
|
||||
"CipherSuite",
|
||||
"SessionTicket",
|
||||
"__version__",
|
||||
)
|
||||
|
||||
# Attach a NullHandler to the top level logger by default
|
||||
# https://docs.python.org/3.3/howto/logging.html#configuring-logging-for-a-library
|
||||
logging.getLogger("quic").addHandler(logging.NullHandler())
|
||||
logging.getLogger("http3").addHandler(logging.NullHandler())
|
||||
7
.venv/lib/python3.9/site-packages/qh3/_compat.py
Normal file
7
.venv/lib/python3.9/site-packages/qh3/_compat.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
DATACLASS_KWARGS = {"slots": True} if sys.version_info >= (3, 10) else {}
|
||||
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
|
||||
UINT_VAR_MAX_SIZE = 8
|
||||
BIN
.venv/lib/python3.9/site-packages/qh3/_hazmat.abi3.so
Executable file
BIN
.venv/lib/python3.9/site-packages/qh3/_hazmat.abi3.so
Executable file
Binary file not shown.
361
.venv/lib/python3.9/site-packages/qh3/_hazmat.pyi
Normal file
361
.venv/lib/python3.9/site-packages/qh3/_hazmat.pyi
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Everything within that module is off the semver guarantees.
|
||||
You use it, you deal with unexpected breakage. Anytime, anywhere.
|
||||
You'd be better off using cryptography directly.
|
||||
|
||||
This module serve exclusively qh3 interests. You have been warned.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Sequence
|
||||
|
||||
class DecompressionFailed(Exception): ...
|
||||
class DecoderStreamError(Exception): ...
|
||||
class EncoderStreamError(Exception): ...
|
||||
class StreamBlocked(Exception): ...
|
||||
|
||||
class QpackDecoder:
|
||||
def __init__(self, max_table_capacity: int, blocked_streams: int) -> None: ...
|
||||
def feed_encoder(self, data: bytes) -> None: ...
|
||||
def feed_header(
|
||||
self, stream_id: int, data: bytes
|
||||
) -> tuple[bytes, list[tuple[bytes, bytes]]]: ...
|
||||
def resume_header(
|
||||
self, stream_id: int
|
||||
) -> tuple[bytes, list[tuple[bytes, bytes]]]: ...
|
||||
|
||||
class QpackEncoder:
|
||||
def apply_settings(
|
||||
self, max_table_capacity: int, dyn_table_capacity: int, blocked_streams: int
|
||||
) -> bytes: ...
|
||||
def encode(
|
||||
self, stream_id: int, headers: list[tuple[bytes, bytes]]
|
||||
) -> tuple[bytes, bytes]: ...
|
||||
def feed_decoder(self, data: bytes) -> None: ...
|
||||
|
||||
class AeadChaCha20Poly1305:
|
||||
def __init__(self, key: bytes, iv: bytes) -> None: ...
|
||||
def encrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
def decrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
|
||||
class AeadAes256Gcm:
|
||||
def __init__(self, key: bytes, iv: bytes) -> None: ...
|
||||
def encrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
def decrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
|
||||
class AeadAes128Gcm:
|
||||
def __init__(self, key: bytes, iv: bytes) -> None: ...
|
||||
def encrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
def encrypt_with_nonce(
|
||||
self, nonce: bytes, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
def decrypt(
|
||||
self, packet_number: int, data: bytes, associated_data: bytes
|
||||
) -> bytes: ...
|
||||
|
||||
class ServerVerifier:
|
||||
def __init__(self, authorities: list[bytes]) -> None: ...
|
||||
def verify(
|
||||
self,
|
||||
peer: bytes,
|
||||
intermediaries: list[bytes],
|
||||
server_name: str,
|
||||
ocsp_response: bytes,
|
||||
) -> None: ...
|
||||
|
||||
class TlsCertUsage(Enum):
|
||||
ServerAuth = 0
|
||||
ClientAuth = 1
|
||||
Both = 2
|
||||
Other = 3
|
||||
|
||||
class Certificate:
|
||||
"""
|
||||
A (very) straightforward class to expose a parsed X509 certificate.
|
||||
This is hazardous material, nothing in there is guaranteed to
|
||||
remain backward compatible.
|
||||
|
||||
Use with care...
|
||||
"""
|
||||
|
||||
def __init__(self, certificate_der: bytes) -> None: ...
|
||||
@property
|
||||
def subject(self):
|
||||
list[tuple[str, str, bytes]]
|
||||
@property
|
||||
def issuer(self):
|
||||
list[tuple[str, str, bytes]]
|
||||
@property
|
||||
def not_valid_after(self) -> int: ...
|
||||
@property
|
||||
def not_valid_before(self) -> int: ...
|
||||
@property
|
||||
def serial_number(self) -> str: ...
|
||||
def get_extension_for_oid(self, oid: str) -> list[tuple[str, bool, bytes]]: ...
|
||||
@property
|
||||
def version(self) -> int: ...
|
||||
def get_ocsp_endpoints(self) -> list[bytes]: ...
|
||||
def get_crl_endpoints(self) -> list[bytes]: ...
|
||||
def get_issuer_endpoints(self) -> list[bytes]: ...
|
||||
def get_subject_alt_names(self) -> list[bytes]: ...
|
||||
def public_bytes(self) -> bytes: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
@property
|
||||
def self_signed(self) -> bool: ...
|
||||
@property
|
||||
def is_ca(self) -> bool: ...
|
||||
@property
|
||||
def usage(self) -> TlsCertUsage: ...
|
||||
def serialize(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def deserialize(src: bytes) -> Certificate: ...
|
||||
|
||||
class Rsa:
|
||||
"""
|
||||
This binding host a RSA Private/Public Keys.
|
||||
Use Oaep (padding) + SHA256 under. Not customizable.
|
||||
"""
|
||||
|
||||
def __init__(self, key_size: int) -> None: ...
|
||||
def encrypt(self, data: bytes) -> bytes: ...
|
||||
def decrypt(self, data: bytes) -> bytes: ...
|
||||
|
||||
class EcPrivateKey:
|
||||
def __init__(self, der_key: bytes, curve_type: int, is_pkcs8: bool) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def sign(self, data: bytes) -> bytes: ...
|
||||
@property
|
||||
def curve_type(self) -> int: ...
|
||||
|
||||
class Ed25519PrivateKey:
|
||||
def __init__(self, pkcs8: bytes) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def sign(self, data: bytes) -> bytes: ...
|
||||
|
||||
class DsaPrivateKey:
|
||||
def __init__(self, pkcs8: bytes) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def sign(self, data: bytes) -> bytes: ...
|
||||
|
||||
class RsaPrivateKey:
|
||||
def __init__(self, pkcs8: bytes) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def sign(self, data: bytes, padding, hash_size: int) -> bytes: ...
|
||||
|
||||
def verify_with_public_key(
|
||||
public_key_raw: bytes, algorithm: int, message: bytes, signature: bytes
|
||||
) -> None: ...
|
||||
|
||||
class X25519ML768KeyExchange:
|
||||
def __init__(self) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def exchange(self, peer_public_key: bytes) -> bytes: ...
|
||||
def shared_ciphertext(self) -> bytes: ...
|
||||
|
||||
class X25519KeyExchange:
|
||||
def __init__(self) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def exchange(self, peer_public_key: bytes) -> bytes: ...
|
||||
|
||||
class ECDHP256KeyExchange:
|
||||
def __init__(self) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def exchange(self, peer_public_key: bytes) -> bytes: ...
|
||||
|
||||
class ECDHP384KeyExchange:
|
||||
def __init__(self) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def exchange(self, peer_public_key: bytes) -> bytes: ...
|
||||
|
||||
class ECDHP521KeyExchange:
|
||||
def __init__(self) -> None: ...
|
||||
def public_key(self) -> bytes: ...
|
||||
def exchange(self, peer_public_key: bytes) -> bytes: ...
|
||||
|
||||
class CryptoError(Exception): ...
|
||||
|
||||
class KeyType(Enum):
|
||||
ECDSA_P256 = 0
|
||||
ECDSA_P384 = 1
|
||||
ECDSA_P521 = 2
|
||||
ED25519 = 3
|
||||
DSA = 4
|
||||
RSA = 5
|
||||
|
||||
class PrivateKeyInfo:
|
||||
"""
|
||||
Load a PEM private key and extract valuable info from it.
|
||||
Does two things, provide a DER encoded key and hint
|
||||
toward its nature (eg. EC, RSA, DSA, etc...)
|
||||
"""
|
||||
|
||||
def __init__(self, raw_pem_content: bytes, password: bytes | None) -> None: ...
|
||||
def public_bytes(self) -> bytes: ...
|
||||
def get_type(self) -> KeyType: ...
|
||||
|
||||
class SelfSignedCertificateError(Exception): ...
|
||||
class InvalidNameCertificateError(Exception): ...
|
||||
class ExpiredCertificateError(Exception): ...
|
||||
class UnacceptableCertificateError(Exception): ...
|
||||
class SignatureError(Exception): ...
|
||||
|
||||
class QUICHeaderProtection:
|
||||
def __init__(self, algorithm: str, key: bytes) -> None: ...
|
||||
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
|
||||
def remove(self, packet: bytes, pn_offset: int) -> tuple[bytes, int]: ...
|
||||
def mask(self, sample: bytes) -> bytes: ...
|
||||
|
||||
class ReasonFlags(Enum):
|
||||
unspecified = 0
|
||||
key_compromise = 1
|
||||
ca_compromise = 2
|
||||
affiliation_changed = 3
|
||||
superseded = 4
|
||||
cessation_of_operation = 5
|
||||
certificate_hold = 6
|
||||
privilege_withdrawn = 9
|
||||
aa_compromise = 10
|
||||
remove_from_crl = 8
|
||||
|
||||
class OCSPResponseStatus(Enum):
|
||||
SUCCESSFUL = 0
|
||||
MALFORMED_REQUEST = 1
|
||||
INTERNAL_ERROR = 2
|
||||
TRY_LATER = 3
|
||||
SIG_REQUIRED = 5
|
||||
UNAUTHORIZED = 6
|
||||
|
||||
class OCSPCertStatus(Enum):
|
||||
GOOD = 0
|
||||
REVOKED = 1
|
||||
UNKNOWN = 2
|
||||
|
||||
class OCSPResponse:
|
||||
def __init__(self, raw_response: bytes) -> None: ...
|
||||
@property
|
||||
def next_update(self) -> int: ...
|
||||
@property
|
||||
def response_status(self) -> OCSPResponseStatus: ...
|
||||
@property
|
||||
def certificate_status(self) -> OCSPCertStatus: ...
|
||||
@property
|
||||
def revocation_reason(self) -> ReasonFlags | None: ...
|
||||
def serialize(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def deserialize(src: bytes) -> OCSPResponse: ...
|
||||
def authenticate_for(self, issuer_der: bytes) -> bool: ...
|
||||
|
||||
class OCSPRequest:
|
||||
def __init__(self, peer_certificate: bytes, issuer_certificate: bytes) -> None: ...
|
||||
def public_bytes(self) -> bytes: ...
|
||||
|
||||
class BufferReadError(ValueError): ...
|
||||
class BufferWriteError(ValueError): ...
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, capacity: int = 0, data: bytes | None = None) -> None: ...
|
||||
@property
|
||||
def capacity(self) -> int: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def data_slice(self, start: int, end: int) -> bytes: ...
|
||||
def eof(self) -> bool: ...
|
||||
def seek(self, pos: int) -> None: ...
|
||||
def tell(self) -> int: ...
|
||||
def pull_bytes(self, length: int) -> bytes: ...
|
||||
def pull_uint8(self) -> int: ...
|
||||
def pull_uint16(self) -> int: ...
|
||||
def pull_uint24(self) -> int: ... # only for OCSP resp parsing!
|
||||
def pull_uint32(self) -> int: ...
|
||||
def pull_uint64(self) -> int: ...
|
||||
def pull_uint_var(self) -> int: ...
|
||||
def push_bytes(self, value: bytes) -> None: ...
|
||||
def push_uint8(self, value: int) -> None: ...
|
||||
def push_uint16(self, value: int) -> None: ...
|
||||
def push_uint32(self, value: int) -> None: ...
|
||||
def push_uint64(self, value: int) -> None: ...
|
||||
def push_uint_var(self, value: int) -> None: ...
|
||||
|
||||
def encode_uint_var(value: int) -> bytes:
|
||||
"""
|
||||
Encode a variable-length unsigned integer.
|
||||
"""
|
||||
|
||||
def size_uint_var(value: int) -> int:
|
||||
"""
|
||||
Return the number of bytes required to encode the given value
|
||||
as a QUIC variable-length unsigned integer.
|
||||
"""
|
||||
|
||||
def idna_encode(text: str) -> bytes:
|
||||
"""using UTS46"""
|
||||
|
||||
def idna_decode(src: bytes) -> str:
|
||||
"""using UTS46"""
|
||||
|
||||
class RangeSet(Sequence):
|
||||
def __init__(self) -> None: ...
|
||||
def add(self, start: int, stop: int | None = None) -> None: ...
|
||||
def bounds(self) -> tuple[int, int]: ...
|
||||
def shift(self) -> tuple[int, int]: ...
|
||||
def subtract(self, start: int, stop: int) -> None: ...
|
||||
def __contains__(self, val: object) -> bool: ...
|
||||
def __eq__(self, other: Any) -> bool: ...
|
||||
def __getitem__(self, key: Any) -> tuple[int, int]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
|
||||
"""
|
||||
Recover a packet number from a truncated packet number.
|
||||
|
||||
See: Appendix A - Sample Packet Number Decoding Algorithm
|
||||
"""
|
||||
|
||||
class QuicPacketPacer:
|
||||
def __init__(self, max_datagram_size: int) -> None: ...
|
||||
def next_send_time(self, now: float) -> float | None: ...
|
||||
def update_after_send(self, now: float) -> None: ...
|
||||
def update_bucket(self, now: float) -> None: ...
|
||||
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: ...
|
||||
|
||||
class QuicRttMonitor:
|
||||
"""
|
||||
Roundtrip time monitor for HyStart.
|
||||
"""
|
||||
def add_rtt(self, rtt: float) -> None: ...
|
||||
def is_rtt_increasing(self, rtt: float, now: float) -> bool: ...
|
||||
|
||||
class RevokedCertificate:
|
||||
serial_number: str
|
||||
reason: ReasonFlags
|
||||
expired_at: int
|
||||
|
||||
class CertificateRevocationList:
|
||||
def __init__(self, crl: bytes) -> None: ...
|
||||
def is_revoked(self, serial_number: str) -> RevokedCertificate | None: ...
|
||||
def serialize(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def deserialize(src: bytes) -> CertificateRevocationList: ...
|
||||
def __len__(self) -> int: ...
|
||||
@property
|
||||
def issuer(self) -> str: ...
|
||||
@property
|
||||
def last_updated_at(self) -> int: ...
|
||||
@property
|
||||
def next_update_at(self) -> int: ...
|
||||
def authenticate_for(self, issuer_der: bytes) -> bool: ...
|
||||
|
||||
def rebuild_chain(leaf: bytes, intermediates: list[bytes]) -> list[bytes]: ...
|
||||
@@ -0,0 +1,3 @@
|
||||
from .client import connect # noqa
|
||||
from .protocol import QuicConnectionProtocol # noqa
|
||||
from .server import serve # noqa
|
||||
104
.venv/lib/python3.9/site-packages/qh3/asyncio/client.py
Normal file
104
.venv/lib/python3.9/site-packages/qh3/asyncio/client.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Callable, cast
|
||||
|
||||
from ..quic.configuration import QuicConfiguration
|
||||
from ..quic.connection import QuicConnection
|
||||
from ..tls import SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["connect"]
|
||||
|
||||
# keep compatibility for Python 3.7 on Windows
|
||||
if not hasattr(socket, "IPPROTO_IPV6"):
|
||||
socket.IPPROTO_IPV6 = 41
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
configuration: QuicConfiguration | None = None,
|
||||
create_protocol: Callable | None = QuicConnectionProtocol,
|
||||
session_ticket_handler: SessionTicketHandler | None = None,
|
||||
stream_handler: QuicStreamHandler | None = None,
|
||||
wait_connected: bool = True,
|
||||
local_port: int = 0,
|
||||
) -> AsyncGenerator[QuicConnectionProtocol]:
|
||||
"""
|
||||
Connect to a QUIC server at the given `host` and `port`.
|
||||
|
||||
:meth:`connect()` returns an awaitable. Awaiting it yields a
|
||||
:class:`~qh3.asyncio.QuicConnectionProtocol` which can be used to
|
||||
create streams.
|
||||
|
||||
:func:`connect` also accepts the following optional arguments:
|
||||
|
||||
* ``configuration`` is a :class:`~qh3.quic.configuration.QuicConfiguration`
|
||||
configuration object.
|
||||
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
|
||||
manages the connection. It should be a callable or class accepting the same
|
||||
arguments as :class:`~qh3.asyncio.QuicConnectionProtocol` and returning
|
||||
an instance of :class:`~qh3.asyncio.QuicConnectionProtocol` or a subclass.
|
||||
* ``session_ticket_handler`` is a callback which is invoked by the TLS
|
||||
engine when a new session ticket is received.
|
||||
* ``stream_handler`` is a callback which is invoked whenever a stream is
|
||||
created. It must accept two arguments: a :class:`asyncio.StreamReader`
|
||||
and a :class:`asyncio.StreamWriter`.
|
||||
* ``local_port`` is the UDP port number that this client wants to bind.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
local_host = "::"
|
||||
|
||||
# if host is not an IP address, pass it to enable SNI
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
server_name = None
|
||||
except ValueError:
|
||||
server_name = host
|
||||
|
||||
# lookup remote address
|
||||
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
|
||||
addr = infos[0][4]
|
||||
if len(addr) == 2:
|
||||
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
|
||||
|
||||
# prepare QUIC connection
|
||||
if configuration is None:
|
||||
configuration = QuicConfiguration(is_client=True)
|
||||
if configuration.server_name is None:
|
||||
configuration.server_name = server_name
|
||||
connection = QuicConnection(
|
||||
configuration=configuration, session_ticket_handler=session_ticket_handler
|
||||
)
|
||||
|
||||
# explicitly enable IPv4/IPv6 dual stack
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||
completed = False
|
||||
try:
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||
sock.bind((local_host, local_port, 0, 0))
|
||||
completed = True
|
||||
finally:
|
||||
if not completed:
|
||||
sock.close()
|
||||
# connect
|
||||
transport, protocol = await loop.create_datagram_endpoint(
|
||||
lambda: create_protocol(connection, stream_handler=stream_handler),
|
||||
sock=sock,
|
||||
)
|
||||
protocol = cast(QuicConnectionProtocol, protocol)
|
||||
try:
|
||||
protocol.connect(addr)
|
||||
if wait_connected:
|
||||
await protocol.wait_connected()
|
||||
yield protocol
|
||||
finally:
|
||||
protocol.close()
|
||||
await protocol.wait_closed()
|
||||
transport.close()
|
||||
254
.venv/lib/python3.9/site-packages/qh3/asyncio/protocol.py
Normal file
254
.venv/lib/python3.9/site-packages/qh3/asyncio/protocol.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
from ..quic import events
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
|
||||
QuicConnectionIdHandler = Callable[[bytes], None]
|
||||
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
|
||||
|
||||
|
||||
class QuicConnectionProtocol(asyncio.DatagramProtocol):
|
||||
def __init__(
|
||||
self, quic: QuicConnection, stream_handler: QuicStreamHandler | None = None
|
||||
):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
self._closed = asyncio.Event()
|
||||
self._connected = False
|
||||
self._connected_waiter: asyncio.Future[None] | None = None
|
||||
self._loop = loop
|
||||
self._ping_waiters: dict[int, asyncio.Future[None]] = {}
|
||||
self._quic = quic
|
||||
self._stream_readers: dict[int, asyncio.StreamReader] = {}
|
||||
self._timer: asyncio.TimerHandle | None = None
|
||||
self._timer_at: float | None = None
|
||||
self._transmit_task: asyncio.Handle | None = None
|
||||
self._transport: asyncio.DatagramTransport | None = None
|
||||
|
||||
# callbacks
|
||||
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
|
||||
self._connection_terminated_handler: Callable[[], None] = lambda: None
|
||||
if stream_handler is not None:
|
||||
self._stream_handler = stream_handler
|
||||
else:
|
||||
self._stream_handler = lambda r, w: None
|
||||
|
||||
def change_connection_id(self) -> None:
|
||||
"""
|
||||
Change the connection ID used to communicate with the peer.
|
||||
|
||||
The previous connection ID will be retired.
|
||||
"""
|
||||
self._quic.change_connection_id()
|
||||
self.transmit()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the connection.
|
||||
"""
|
||||
self._quic.close()
|
||||
self.transmit()
|
||||
|
||||
def connect(self, addr: NetworkAddress) -> None:
|
||||
"""
|
||||
Initiate the TLS handshake.
|
||||
|
||||
This method can only be called for clients and a single time.
|
||||
"""
|
||||
self._quic.connect(addr, now=self._loop.time())
|
||||
self.transmit()
|
||||
|
||||
async def create_stream(
|
||||
self, is_unidirectional: bool = False
|
||||
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
"""
|
||||
Create a QUIC stream and return a pair of (reader, writer) objects.
|
||||
|
||||
The returned reader and writer objects are instances of
|
||||
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
|
||||
"""
|
||||
stream_id = self._quic.get_next_available_stream_id(
|
||||
is_unidirectional=is_unidirectional
|
||||
)
|
||||
return self._create_stream(stream_id)
|
||||
|
||||
def request_key_update(self) -> None:
|
||||
"""
|
||||
Request an update of the encryption keys.
|
||||
"""
|
||||
self._quic.request_key_update()
|
||||
self.transmit()
|
||||
|
||||
async def ping(self) -> None:
|
||||
"""
|
||||
Ping the peer and wait for the response.
|
||||
"""
|
||||
waiter = self._loop.create_future()
|
||||
uid = id(waiter)
|
||||
self._ping_waiters[uid] = waiter
|
||||
self._quic.send_ping(uid)
|
||||
self.transmit()
|
||||
await asyncio.shield(waiter)
|
||||
|
||||
def transmit(self) -> None:
|
||||
"""
|
||||
Send pending datagrams to the peer and arm the timer if needed.
|
||||
"""
|
||||
self._transmit_task = None
|
||||
|
||||
# send datagrams
|
||||
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
|
||||
self._transport.sendto(data, addr)
|
||||
|
||||
# re-arm timer
|
||||
timer_at = self._quic.get_timer()
|
||||
if self._timer is not None and self._timer_at != timer_at:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
if self._timer is None and timer_at is not None:
|
||||
self._timer = self._loop.call_at(timer_at, self._handle_timer)
|
||||
self._timer_at = timer_at
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
"""
|
||||
Wait for the connection to be closed.
|
||||
"""
|
||||
await self._closed.wait()
|
||||
|
||||
async def wait_connected(self) -> None:
|
||||
"""
|
||||
Wait for the TLS handshake to complete.
|
||||
"""
|
||||
assert self._connected_waiter is None, "already awaiting connected"
|
||||
if not self._connected:
|
||||
self._connected_waiter = self._loop.create_future()
|
||||
await asyncio.shield(self._connected_waiter)
|
||||
|
||||
# asyncio.Transport
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
self._transport = cast(asyncio.DatagramTransport, transport)
|
||||
|
||||
def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None:
|
||||
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
# overridable
|
||||
|
||||
def quic_event_received(self, event: events.QuicEvent) -> None:
|
||||
"""
|
||||
Called when a QUIC event is received.
|
||||
|
||||
Reimplement this in your subclass to handle the events.
|
||||
"""
|
||||
# FIXME: move this to a subclass
|
||||
if isinstance(event, events.ConnectionTerminated):
|
||||
for reader in self._stream_readers.values():
|
||||
reader.feed_eof()
|
||||
elif isinstance(event, events.StreamDataReceived):
|
||||
reader = self._stream_readers.get(event.stream_id, None)
|
||||
if reader is None:
|
||||
reader, writer = self._create_stream(event.stream_id)
|
||||
self._stream_handler(reader, writer)
|
||||
reader.feed_data(event.data)
|
||||
if event.end_stream:
|
||||
reader.feed_eof()
|
||||
|
||||
# private
|
||||
|
||||
def _create_stream(
|
||||
self, stream_id: int
|
||||
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
|
||||
adapter = QuicStreamAdapter(self, stream_id)
|
||||
reader = asyncio.StreamReader()
|
||||
writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
|
||||
self._stream_readers[stream_id] = reader
|
||||
return reader, writer
|
||||
|
||||
def _handle_timer(self) -> None:
|
||||
now = max(self._timer_at, self._loop.time())
|
||||
self._timer = None
|
||||
self._timer_at = None
|
||||
self._quic.handle_timer(now=now)
|
||||
self._process_events()
|
||||
self.transmit()
|
||||
|
||||
def _process_events(self) -> None:
|
||||
event = self._quic.next_event()
|
||||
while event is not None:
|
||||
if isinstance(event, events.ConnectionIdIssued):
|
||||
self._connection_id_issued_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionIdRetired):
|
||||
self._connection_id_retired_handler(event.connection_id)
|
||||
elif isinstance(event, events.ConnectionTerminated):
|
||||
self._connection_terminated_handler()
|
||||
|
||||
# abort connection waiter
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected_waiter = None
|
||||
waiter.set_exception(ConnectionError)
|
||||
|
||||
# abort ping waiters
|
||||
for waiter in self._ping_waiters.values():
|
||||
waiter.set_exception(ConnectionError)
|
||||
self._ping_waiters.clear()
|
||||
|
||||
self._closed.set()
|
||||
elif isinstance(event, events.HandshakeCompleted):
|
||||
if self._connected_waiter is not None:
|
||||
waiter = self._connected_waiter
|
||||
self._connected = True
|
||||
self._connected_waiter = None
|
||||
waiter.set_result(None)
|
||||
elif isinstance(event, events.PingAcknowledged):
|
||||
waiter = self._ping_waiters.pop(event.uid, None)
|
||||
if waiter is not None:
|
||||
waiter.set_result(None)
|
||||
self.quic_event_received(event)
|
||||
event = self._quic.next_event()
|
||||
|
||||
def _transmit_soon(self) -> None:
|
||||
if self._transmit_task is None:
|
||||
self._transmit_task = self._loop.call_soon(self.transmit)
|
||||
|
||||
|
||||
class QuicStreamAdapter(asyncio.Transport):
|
||||
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
|
||||
super().__init__()
|
||||
|
||||
self.protocol = protocol
|
||||
self.stream_id = stream_id
|
||||
self._closing = False
|
||||
|
||||
def can_write_eof(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_extra_info(self, name: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get information about the underlying QUIC stream.
|
||||
"""
|
||||
if name == "stream_id":
|
||||
return self.stream_id
|
||||
|
||||
def write(self, data) -> None:
|
||||
self.protocol._quic.send_stream_data(self.stream_id, data)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def write_eof(self) -> None:
|
||||
if self._closing:
|
||||
return
|
||||
self._closing = True
|
||||
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
|
||||
self.protocol._transmit_soon()
|
||||
|
||||
def close(self) -> None:
|
||||
self.write_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
return self._closing
|
||||
217
.venv/lib/python3.9/site-packages/qh3/asyncio/server.py
Normal file
217
.venv/lib/python3.9/site-packages/qh3/asyncio/server.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, cast
|
||||
|
||||
from .._hazmat import Buffer
|
||||
from ..quic.configuration import QuicConfiguration
|
||||
from ..quic.connection import NetworkAddress, QuicConnection
|
||||
from ..quic.packet import (
|
||||
QuicPacketType,
|
||||
encode_quic_retry,
|
||||
encode_quic_version_negotiation,
|
||||
pull_quic_header,
|
||||
)
|
||||
from ..quic.retry import QuicRetryTokenHandler
|
||||
from ..tls import SessionTicketFetcher, SessionTicketHandler
|
||||
from .protocol import QuicConnectionProtocol, QuicStreamHandler
|
||||
|
||||
__all__ = ["serve"]
|
||||
|
||||
|
||||
class QuicServer(asyncio.DatagramProtocol):
    """
    Datagram protocol that accepts QUIC packets on a single UDP socket,
    demultiplexes them by connection ID, and dispatches each packet to the
    :class:`QuicConnectionProtocol` handling that connection.
    """

    def __init__(
        self,
        *,
        configuration: QuicConfiguration,
        create_protocol: Callable = QuicConnectionProtocol,
        session_ticket_fetcher: SessionTicketFetcher | None = None,
        session_ticket_handler: SessionTicketHandler | None = None,
        retry: bool = False,
        stream_handler: QuicStreamHandler | None = None,
    ) -> None:
        self._configuration = configuration
        self._create_protocol = create_protocol
        self._loop = asyncio.get_running_loop()
        # Connection ID -> protocol. Several IDs may map to the same protocol
        # (see datagram_received / _connection_id_issued).
        self._protocols: dict[bytes, QuicConnectionProtocol] = {}
        self._session_ticket_fetcher = session_ticket_fetcher
        self._session_ticket_handler = session_ticket_handler
        # Set once connection_made() is called by the event loop.
        self._transport: asyncio.DatagramTransport | None = None

        self._stream_handler = stream_handler

        # When retry is enabled, client addresses are validated with a retry
        # token before any connection state is created.
        if retry:
            self._retry = QuicRetryTokenHandler()
        else:
            self._retry = None

    def close(self):
        """Close every tracked connection, then the underlying transport."""
        for protocol in set(self._protocols.values()):
            protocol.close()
        self._protocols.clear()
        # NOTE(review): assumes connection_made() has run; _transport is None
        # before that — confirm close() is never called earlier.
        self._transport.close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self._transport = cast(asyncio.DatagramTransport, transport)

    def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None:
        """
        Handle one UDP datagram: parse its QUIC header, perform version
        negotiation / retry-based address validation as needed, create a new
        connection for valid Initial packets, and forward the datagram to the
        owning protocol.
        """
        data = cast(bytes, data)
        buf = Buffer(data=data)

        # Silently drop datagrams whose header cannot be parsed.
        try:
            header = pull_quic_header(
                buf, host_cid_length=self._configuration.connection_id_length
            )
        except ValueError:
            return

        # version negotiation
        if (
            header.version is not None
            and header.version not in self._configuration.supported_versions
        ):
            self._transport.sendto(
                encode_quic_version_negotiation(
                    source_cid=header.destination_cid,
                    destination_cid=header.source_cid,
                    supported_versions=self._configuration.supported_versions,
                ),
                addr,
            )
            return

        protocol = self._protocols.get(header.destination_cid, None)
        original_destination_connection_id: bytes | None = None
        retry_source_connection_id: bytes | None = None
        # Only an Initial packet of at least 1200 bytes may open a connection.
        if (
            protocol is None
            and len(data) >= 1200
            and header.packet_type == QuicPacketType.INITIAL
        ):
            # retry-based address validation
            if self._retry is not None:
                if not header.token:
                    # create a retry token and ask the client to come back with it
                    source_cid = os.urandom(8)
                    self._transport.sendto(
                        encode_quic_retry(
                            version=header.version,
                            source_cid=source_cid,
                            destination_cid=header.source_cid,
                            original_destination_cid=header.destination_cid,
                            retry_token=self._retry.create_token(
                                addr, header.destination_cid, source_cid
                            ),
                        ),
                        addr,
                    )
                    return
                else:
                    # validate retry token; drop the packet if it is invalid
                    try:
                        (
                            original_destination_connection_id,
                            retry_source_connection_id,
                        ) = self._retry.validate_token(addr, header.token)
                    except ValueError:
                        return
            else:
                original_destination_connection_id = header.destination_cid

            # create new connection
            connection = QuicConnection(
                configuration=self._configuration,
                original_destination_connection_id=original_destination_connection_id,
                retry_source_connection_id=retry_source_connection_id,
                session_ticket_fetcher=self._session_ticket_fetcher,
                session_ticket_handler=self._session_ticket_handler,
            )
            protocol = self._create_protocol(
                connection, stream_handler=self._stream_handler
            )
            protocol.connection_made(self._transport)

            # register callbacks so the protocol keeps our CID map up to date
            protocol._connection_id_issued_handler = partial(
                self._connection_id_issued, protocol=protocol
            )
            protocol._connection_id_retired_handler = partial(
                self._connection_id_retired, protocol=protocol
            )
            protocol._connection_terminated_handler = partial(
                self._connection_terminated, protocol=protocol
            )

            # Register both the client-chosen destination CID and our host CID.
            self._protocols[header.destination_cid] = protocol
            self._protocols[connection.host_cid] = protocol

        if protocol is not None:
            protocol.datagram_received(data, addr)

    def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
        # A new connection ID now routes to this protocol.
        self._protocols[cid] = protocol

    def _connection_id_retired(
        self, cid: bytes, protocol: QuicConnectionProtocol
    ) -> None:
        # Stop routing the retired connection ID.
        assert self._protocols[cid] == protocol
        del self._protocols[cid]

    def _connection_terminated(self, protocol: QuicConnectionProtocol):
        # Drop every connection ID that still routes to this protocol.
        for cid, proto in list(self._protocols.items()):
            if proto == protocol:
                del self._protocols[cid]
|
||||
|
||||
|
||||
async def serve(
    host: str,
    port: int,
    *,
    configuration: QuicConfiguration,
    create_protocol: Callable = QuicConnectionProtocol,
    session_ticket_fetcher: SessionTicketFetcher | None = None,
    session_ticket_handler: SessionTicketHandler | None = None,
    retry: bool = False,
    stream_handler: QuicStreamHandler | None = None,
) -> QuicServer:
    """
    Start a QUIC server at the given `host` and `port`.

    :func:`serve` requires a :class:`~qh3.quic.configuration.QuicConfiguration`
    containing TLS certificate and private key as the ``configuration`` argument.

    :func:`serve` also accepts the following optional arguments:

    * ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
      manages the connection. It should be a callable or class accepting the same
      arguments as :class:`~qh3.asyncio.QuicConnectionProtocol` and returning
      an instance of :class:`~qh3.asyncio.QuicConnectionProtocol` or a subclass.
    * ``session_ticket_fetcher`` is a callback which is invoked by the TLS
      engine when a session ticket is presented by the peer. It should return
      the session ticket with the specified ID or `None` if it is not found.
    * ``session_ticket_handler`` is a callback which is invoked by the TLS
      engine when a new session ticket is issued. It should store the session
      ticket for future lookup.
    * ``retry`` specifies whether client addresses should be validated prior to
      the cryptographic handshake using a retry packet.
    * ``stream_handler`` is a callback which is invoked whenever a stream is
      created. It must accept two arguments: a :class:`asyncio.StreamReader`
      and a :class:`asyncio.StreamWriter`.
    """

    loop = asyncio.get_running_loop()

    # The endpoint's transport is owned by the QuicServer protocol, which is
    # what callers interact with; the transport itself is not returned.
    _, protocol = await loop.create_datagram_endpoint(
        lambda: QuicServer(
            configuration=configuration,
            create_protocol=create_protocol,
            session_ticket_fetcher=session_ticket_fetcher,
            session_ticket_handler=session_ticket_handler,
            retry=retry,
            stream_handler=stream_handler,
        ),
        local_addr=(host, port),
    )
    return protocol
|
||||
1215
.venv/lib/python3.9/site-packages/qh3/h3/connection.py
Normal file
1215
.venv/lib/python3.9/site-packages/qh3/h3/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
116
.venv/lib/python3.9/site-packages/qh3/h3/events.py
Normal file
116
.venv/lib/python3.9/site-packages/qh3/h3/events.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
Headers = List[Tuple[bytes, bytes]]
|
||||
|
||||
|
||||
class H3Event:
    """
    Base class for HTTP/3 events.

    Concrete events are plain dataclasses describing HTTP/3 activity.
    """
|
||||
|
||||
|
||||
@dataclass
class DataReceived(H3Event):
    """
    The DataReceived event is fired whenever data is received on a stream from
    the remote peer.
    """

    data: bytes
    "The data which was received."

    stream_id: int
    "The ID of the stream the data was received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    push_id: int | None = None
    "The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
class DatagramReceived(H3Event):
    """
    The DatagramReceived event is fired whenever a datagram is received from
    the remote peer.
    """

    data: bytes
    "The data which was received."

    flow_id: int
    "The ID of the flow the data was received for."
|
||||
|
||||
|
||||
@dataclass
class InformationalHeadersReceived(H3Event):
    """
    This event is fired whenever an informational (1xx) response has been
    caught inflight. The stream cannot be ended there.
    """

    headers: Headers
    "The headers."

    stream_id: int
    "The ID of the stream the headers were received for."
|
||||
|
||||
|
||||
@dataclass
class HeadersReceived(H3Event):
    """
    The HeadersReceived event is fired whenever headers are received.
    """

    headers: Headers
    "The headers."

    stream_id: int
    "The ID of the stream the headers were received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    push_id: int | None = None
    "The Push ID or `None` if this is not a push."
|
||||
|
||||
|
||||
@dataclass
class PushPromiseReceived(H3Event):
    """
    The PushPromiseReceived event is fired whenever a push promise has been
    received from the remote peer.
    """

    headers: Headers
    "The request headers."

    push_id: int
    "The Push ID of the push promise."

    stream_id: int
    "The Stream ID of the stream that the push is related to."
|
||||
|
||||
|
||||
@dataclass
class WebTransportStreamDataReceived(H3Event):
    """
    The WebTransportStreamDataReceived is fired whenever data is received
    for a WebTransport stream.
    """

    data: bytes
    "The data which was received."

    stream_id: int
    "The ID of the stream the data was received for."

    stream_ended: bool
    "Whether the STREAM frame had the FIN bit set."

    session_id: int
    "The ID of the session the data was received for."
|
||||
13
.venv/lib/python3.9/site-packages/qh3/h3/exceptions.py
Normal file
13
.venv/lib/python3.9/site-packages/qh3/h3/exceptions.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class H3Error(Exception):
    """
    Base class for HTTP/3 exceptions.

    Catch this to handle any error raised by the h3 layer.
    """
|
||||
|
||||
|
||||
class NoAvailablePushIDError(H3Error):
    """
    Raised when a push is requested but there are no available push IDs left.
    """
|
||||
1
.venv/lib/python3.9/site-packages/qh3/py.typed
Normal file
1
.venv/lib/python3.9/site-packages/qh3/py.typed
Normal file
@@ -0,0 +1 @@
|
||||
Marker
|
||||
191
.venv/lib/python3.9/site-packages/qh3/quic/configuration.py
Normal file
191
.venv/lib/python3.9/site-packages/qh3/quic/configuration.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from os import PathLike
|
||||
from re import split
|
||||
from typing import TYPE_CHECKING, TextIO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._hazmat import Certificate as X509Certificate
|
||||
from .._hazmat import DsaPrivateKey, EcPrivateKey, Ed25519PrivateKey, RsaPrivateKey
|
||||
|
||||
from ..tls import (
|
||||
CipherSuite,
|
||||
SessionTicket,
|
||||
load_pem_private_key,
|
||||
load_pem_x509_certificates,
|
||||
)
|
||||
from .logger import QuicLogger
|
||||
from .packet import QuicProtocolVersion
|
||||
from .packet_builder import PACKET_MAX_SIZE
|
||||
|
||||
|
||||
@dataclass
class QuicConfiguration:
    """
    A QUIC configuration.

    Holds both transport-level settings (flow control, datagram sizes,
    supported versions) and TLS material (certificates, private key,
    verification options).
    """

    alpn_protocols: list[str] | None = None
    """
    A list of supported ALPN protocols.
    """

    connection_id_length: int = 8
    """
    The length in bytes of local connection IDs.
    """

    idle_timeout: float = 60.0
    """
    The idle timeout in seconds.

    The connection is terminated if nothing is received for the given duration.
    """

    is_client: bool = True
    """
    Whether this is the client side of the QUIC connection.
    """

    max_data: int = 1048576
    """
    Connection-wide flow control limit.
    """

    max_datagram_size: int = PACKET_MAX_SIZE
    """
    The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
    """

    probe_datagram_size: bool = True
    """
    Enable path MTU discovery. Client-only.
    """

    max_stream_data: int = 1048576
    """
    Per-stream flow control limit.
    """

    quic_logger: QuicLogger | None = None
    """
    The :class:`~qh3.quic.logger.QuicLogger` instance to log events to.
    """

    secrets_log_file: TextIO | None = None
    """
    A file-like object in which to log traffic secrets.

    This is useful to analyze traffic captures with Wireshark.
    """

    server_name: str | None = None
    """
    The server name to send during the TLS handshake as the Server Name
    Indication.

    .. note:: This is only used by clients.
    """

    session_ticket: SessionTicket | None = None
    """
    The TLS session ticket which should be used for session resumption.
    """

    # Certificate-verification tuning. NOTE(review): exact semantics live in
    # the tls layer — names suggest CN fallback / fingerprint pinning; confirm.
    hostname_checks_common_name: bool = False
    assert_fingerprint: str | None = None
    verify_hostname: bool = True

    # CA material; populated by load_verify_locations().
    cadata: bytes | None = None
    cafile: str | None = None
    capath: str | None = None

    # Leaf certificate and intermediates; populated by load_cert_chain().
    certificate: X509Certificate | None = None
    certificate_chain: list[X509Certificate] = field(default_factory=list)

    # None means the tls layer's default cipher suites are used (assumed).
    cipher_suites: list[CipherSuite] | None = None
    # Initial round-trip time estimate in seconds (assumed — confirm usage).
    initial_rtt: float = 0.1

    max_datagram_frame_size: int | None = None
    original_version: int | None = None

    # Populated by load_cert_chain() from the PEM key material.
    private_key: (
        EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None
    ) = None

    quantum_readiness_test: bool = False
    # QUIC v1 and v2 are offered by default.
    supported_versions: list[int] = field(
        default_factory=lambda: [
            QuicProtocolVersion.VERSION_1,
            QuicProtocolVersion.VERSION_2,
        ]
    )
    verify_mode: int | None = None

    def load_cert_chain(
        self,
        certfile: str | bytes | PathLike,
        keyfile: str | bytes | PathLike | None = None,
        password: bytes | str | None = None,
    ) -> None:
        """
        Load a private key and the corresponding certificate.

        ``certfile`` and ``keyfile`` may each be either PEM content or a path
        to a PEM file; the key may also be embedded in ``certfile``.
        """

        # Normalize both arguments to bytes (paths become encoded path strings).
        if isinstance(certfile, str):
            certfile = certfile.encode()
        elif isinstance(certfile, PathLike):
            certfile = str(certfile).encode()

        if keyfile is not None:
            if isinstance(keyfile, str):
                keyfile = keyfile.encode()
            elif isinstance(keyfile, PathLike):
                keyfile = str(keyfile).encode()

        # we either have the certificate or a file path in certfile/keyfile.
        if b"-----BEGIN" not in certfile:
            with open(certfile, "rb") as fp:
                certfile = fp.read()
            if keyfile is not None:
                with open(keyfile, "rb") as fp:
                    keyfile = fp.read()

        # Split an embedded private key off the certificate material, honoring
        # the file's line-ending convention.
        is_crlf = b"-----\r\n" in certfile
        boundary = (
            b"-----BEGIN PRIVATE KEY-----\n"
            if not is_crlf
            else b"-----BEGIN PRIVATE KEY-----\r\n"
        )
        chunks = split(b"\n" + boundary, certfile)

        certificates = load_pem_x509_certificates(chunks[0])

        # Exactly one embedded key found: restore the boundary the split consumed.
        if len(chunks) == 2:
            private_key = boundary + chunks[1]
            self.private_key = load_pem_private_key(private_key)

        self.certificate = certificates[0]
        self.certificate_chain = certificates[1:]

        # An explicit keyfile takes precedence over any embedded key.
        if keyfile is not None:
            self.private_key = load_pem_private_key(
                keyfile,
                password=(
                    password.encode("utf8") if isinstance(password, str) else password
                ),
            )

    def load_verify_locations(
        self,
        cafile: str | None = None,
        capath: str | None = None,
        cadata: bytes | None = None,
    ) -> None:
        """
        Load a set of "certification authority" (CA) certificates used to
        validate other peers' certificates.
        """
        self.cafile = cafile
        self.capath = capath
        self.cadata = cadata
|
||||
3804
.venv/lib/python3.9/site-packages/qh3/quic/connection.py
Normal file
3804
.venv/lib/python3.9/site-packages/qh3/quic/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
284
.venv/lib/python3.9/site-packages/qh3/quic/crypto.py
Normal file
284
.venv/lib/python3.9/site-packages/qh3/quic/crypto.py
Normal file
@@ -0,0 +1,284 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
from typing import Callable
|
||||
|
||||
from .._hazmat import (
|
||||
AeadAes128Gcm,
|
||||
AeadAes256Gcm,
|
||||
AeadChaCha20Poly1305,
|
||||
CryptoError,
|
||||
decode_packet_number,
|
||||
)
|
||||
from .._hazmat import QUICHeaderProtection as HeaderProtection
|
||||
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
|
||||
from .packet import (
|
||||
QuicProtocolVersion,
|
||||
is_long_header,
|
||||
)
|
||||
|
||||
# Maps each supported TLS 1.3 cipher suite to its
# (header-protection cipher, AEAD cipher) name pair.
CIPHER_SUITES = {
    CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
    CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
    CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
}
# Initial packets are always protected with AES-128-GCM (RFC 9001 §5.2).
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
# HKDF salts for deriving initial secrets, one per QUIC version.
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
# Number of ciphertext bytes sampled for header protection.
SAMPLE_SIZE = 16


# Signature of the key setup/teardown notification callbacks.
Callback = Callable[[str], None]
|
||||
|
||||
|
||||
def NoCallback(trigger: str) -> None:
    """Default callback: silently ignore the *trigger* notification."""
|
||||
|
||||
|
||||
class KeyUnavailableError(CryptoError):
    """Raised when a packet is decrypted before the decryption key is installed."""

    pass
|
||||
|
||||
|
||||
def derive_key_iv_hp(
    *, cipher_suite: CipherSuite, secret: bytes, version: int
) -> tuple[bytes, bytes, bytes]:
    """
    Derive the packet-protection key, IV and header-protection key from *secret*
    for the given cipher suite and QUIC version.
    """
    algorithm = cipher_suite_hash(cipher_suite)

    # AES-256-GCM and ChaCha20-Poly1305 take 256-bit keys; AES-128-GCM 128-bit.
    key_size = (
        32
        if cipher_suite
        in (CipherSuite.AES_256_GCM_SHA384, CipherSuite.CHACHA20_POLY1305_SHA256)
        else 16
    )

    # QUIC v2 uses distinct HKDF labels ("quicv2 ...") from v1 ("quic ...").
    label_prefix = b"quicv2 " if version == QuicProtocolVersion.VERSION_2 else b"quic "
    return (
        hkdf_expand_label(algorithm, secret, label_prefix + b"key", b"", key_size),
        hkdf_expand_label(algorithm, secret, label_prefix + b"iv", b"", 12),
        hkdf_expand_label(algorithm, secret, label_prefix + b"hp", b"", key_size),
    )
|
||||
|
||||
|
||||
class CryptoContext:
    """
    Packet protection state for one direction and one encryption level:
    AEAD for payload protection plus header protection, with key-phase
    tracking for 1-RTT key updates.
    """

    # __slots__: many contexts exist per connection; keep them lean.
    __slots__ = (
        "aead",
        "cipher_suite",
        "hp",
        "key_phase",
        "secret",
        "version",
        "_setup_cb",
        "_teardown_cb",
    )

    def __init__(
        self,
        key_phase: int = 0,
        setup_cb: Callback = NoCallback,
        teardown_cb: Callback = NoCallback,
    ) -> None:
        # All crypto state is unset until setup() installs keys.
        self.aead: AeadChaCha20Poly1305 | AeadAes128Gcm | AeadAes256Gcm | None = None
        self.cipher_suite: CipherSuite | None = None
        self.hp: HeaderProtection | None = None
        self.key_phase = key_phase
        self.secret: bytes | None = None
        self.version: int | None = None
        self._setup_cb = setup_cb
        self._teardown_cb = teardown_cb

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> tuple[bytes, bytes, int, bool]:
        """
        Remove header protection and decrypt *packet*.

        Returns (plain_header, payload, packet_number, key_phase_changed).
        Raises KeyUnavailableError if no decryption key is installed.
        """
        if self.aead is None:
            raise KeyUnavailableError("Decryption key is not available")

        # header protection
        plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
        first_byte = plain_header[0]

        # packet number: expand the truncated value against the expected one
        pn_length = (first_byte & 0x03) + 1
        packet_number = decode_packet_number(
            packet_number, pn_length * 8, expected_packet_number
        )

        # detect key phase change (short-header packets only)
        crypto = self
        if not is_long_header(first_byte):
            key_phase = (first_byte & 4) >> 2
            if key_phase != self.key_phase:
                crypto = next_key_phase(self)

        # payload protection: the plain header is the AEAD associated data
        payload = crypto.aead.decrypt(
            packet_number, packet[len(plain_header) :], plain_header
        )

        return plain_header, payload, packet_number, crypto != self

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """Encrypt the payload and apply header protection."""
        assert self.is_valid(), "Encryption key is not available"

        # payload protection
        protected_payload = self.aead.encrypt(
            packet_number, plain_payload, plain_header
        )

        # header protection
        return self.hp.apply(plain_header, protected_payload)

    def is_valid(self) -> bool:
        """Whether keys have been installed via setup() and not torn down."""
        return self.aead is not None

    def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
        """Derive and install AEAD and header-protection keys from *secret*."""
        hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]

        key, iv, hp = derive_key_iv_hp(
            cipher_suite=cipher_suite,
            secret=secret,
            version=version,
        )

        if aead_cipher_name == b"chacha20-poly1305":
            self.aead = AeadChaCha20Poly1305(key, iv)
        elif aead_cipher_name == b"aes-256-gcm":
            self.aead = AeadAes256Gcm(key, iv)
        elif aead_cipher_name == b"aes-128-gcm":
            self.aead = AeadAes128Gcm(key, iv)
        else:
            raise CryptoError(f"Invalid cipher name: {aead_cipher_name.decode()}")

        self.cipher_suite = cipher_suite
        self.hp = HeaderProtection(hp_cipher_name.decode(), hp)
        self.secret = secret
        self.version = version

        # trigger callback
        self._setup_cb("tls")

    def teardown(self) -> None:
        """Discard all key material (key_phase and version are kept)."""
        self.aead = None
        self.cipher_suite = None
        self.hp = None
        self.secret = None

        # trigger callback
        self._teardown_cb("tls")
|
||||
|
||||
|
||||
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
    """Adopt the AEAD, key phase and secret of *crypto*, then notify *trigger*."""
    self.aead = crypto.aead
    self.key_phase = crypto.key_phase
    self.secret = crypto.secret

    # trigger callback
    self._setup_cb(trigger)
|
||||
|
||||
|
||||
def next_key_phase(self: CryptoContext) -> CryptoContext:
    """
    Build a fresh CryptoContext for the opposite key phase, with keys derived
    from the current secret via the "quic ku" HKDF label.
    """
    algorithm = cipher_suite_hash(self.cipher_suite)

    # Flip the key phase bit for the new context.
    crypto = CryptoContext(key_phase=int(not self.key_phase))
    crypto.setup(
        cipher_suite=self.cipher_suite,
        secret=hkdf_expand_label(
            algorithm, self.secret, b"quic ku", b"", int(algorithm / 8)
        ),
        version=self.version,
    )
    return crypto
|
||||
|
||||
|
||||
class CryptoPair:
    """
    A receive/send pair of :class:`CryptoContext` objects for one encryption
    level, coordinating key updates across both directions.
    """

    __slots__ = (
        "aead_tag_size",
        "recv",
        "send",
        "_update_key_requested",
    )

    def __init__(
        self,
        recv_setup_cb: Callback = NoCallback,
        recv_teardown_cb: Callback = NoCallback,
        send_setup_cb: Callback = NoCallback,
        send_teardown_cb: Callback = NoCallback,
    ) -> None:
        # AEAD authentication tag length in bytes.
        self.aead_tag_size = 16
        self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
        self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
        # Set by update_key(); applied lazily on the next encrypt_packet().
        self._update_key_requested = False

    def decrypt_packet(
        self, packet: bytes, encrypted_offset: int, expected_packet_number: int
    ) -> tuple[bytes, bytes, int]:
        """Decrypt *packet*, applying a peer-initiated key update if detected."""
        plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
            packet, encrypted_offset, expected_packet_number
        )
        if update_key:
            self._update_key("remote_update")
        return plain_header, payload, packet_number

    def encrypt_packet(
        self, plain_header: bytes, plain_payload: bytes, packet_number: int
    ) -> bytes:
        """Encrypt a packet, first applying any locally requested key update."""
        if self._update_key_requested:
            self._update_key("local_update")
        return self.send.encrypt_packet(plain_header, plain_payload, packet_number)

    def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
        """Install the Initial-level keys derived from the connection ID *cid*."""
        # The "client in" / "server in" labels are chosen per our role.
        if is_client:
            recv_label, send_label = b"server in", b"client in"
        else:
            recv_label, send_label = b"client in", b"server in"

        if version == QuicProtocolVersion.VERSION_2:
            initial_salt = INITIAL_SALT_VERSION_2
        else:
            initial_salt = INITIAL_SALT_VERSION_1

        algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
        digest_size = int(algorithm / 8)
        initial_secret = hkdf_extract(algorithm, initial_salt, cid)
        self.recv.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, recv_label, b"", digest_size
            ),
            version=version,
        )
        self.send.setup(
            cipher_suite=INITIAL_CIPHER_SUITE,
            secret=hkdf_expand_label(
                algorithm, initial_secret, send_label, b"", digest_size
            ),
            version=version,
        )

    def teardown(self) -> None:
        """Discard key material in both directions."""
        self.recv.teardown()
        self.send.teardown()

    def update_key(self) -> None:
        """Request a key update; it takes effect on the next encrypt_packet()."""
        self._update_key_requested = True

    @property
    def key_phase(self) -> int:
        # Report the phase that will be in effect once a pending update applies.
        if self._update_key_requested:
            return int(not self.recv.key_phase)
        else:
            return self.recv.key_phase

    def _update_key(self, trigger: str) -> None:
        # Advance both directions to the next key phase together.
        apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
        apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
        self._update_key_requested = False
|
||||
127
.venv/lib/python3.9/site-packages/qh3/quic/events.py
Normal file
127
.venv/lib/python3.9/site-packages/qh3/quic/events.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class QuicEvent:
    """
    Base class for QUIC events.

    Concrete events are plain dataclasses describing connection activity.
    """

    pass
|
||||
|
||||
|
||||
@dataclass
class ConnectionIdIssued(QuicEvent):
    """Signals that a connection ID was issued."""

    # The connection ID which was issued.
    connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
class ConnectionIdRetired(QuicEvent):
    """Signals that a connection ID was retired."""

    # The connection ID which was retired.
    connection_id: bytes
|
||||
|
||||
|
||||
@dataclass
class ConnectionTerminated(QuicEvent):
    """
    The ConnectionTerminated event is fired when the QUIC connection is terminated.
    """

    error_code: int
    "The error code which was specified when closing the connection."

    frame_type: int | None
    "The frame type which caused the connection to be closed, or `None`."

    reason_phrase: str
    "The human-readable reason for which the connection was closed."
|
||||
|
||||
|
||||
@dataclass
class DatagramFrameReceived(QuicEvent):
    """
    The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
    """

    data: bytes
    "The data which was received."
|
||||
|
||||
|
||||
@dataclass
class HandshakeCompleted(QuicEvent):
    """
    The HandshakeCompleted event is fired when the TLS handshake completes.
    """

    alpn_protocol: str | None
    "The protocol which was negotiated using ALPN, or `None`."

    early_data_accepted: bool
    "Whether early (0-RTT) data was accepted by the remote peer."

    session_resumed: bool
    "Whether a TLS session was resumed."
|
||||
|
||||
|
||||
@dataclass
class PingAcknowledged(QuicEvent):
    """
    The PingAcknowledged event is fired when a PING frame is acknowledged.
    """

    uid: int
    "The unique ID of the PING."
|
||||
|
||||
|
||||
@dataclass
class ProtocolNegotiated(QuicEvent):
    """
    The ProtocolNegotiated event is fired when ALPN negotiation completes.
    """

    alpn_protocol: str | None
    "The protocol which was negotiated using ALPN, or `None`."
|
||||
|
||||
|
||||
@dataclass
class StreamDataReceived(QuicEvent):
    """
    The StreamDataReceived event is fired whenever data is received on a
    stream.
    """

    data: bytes
    "The data which was received."

    end_stream: bool
    "Whether the STREAM frame had the FIN bit set."

    stream_id: int
    "The ID of the stream the data was received for."
|
||||
|
||||
|
||||
@dataclass
class StopSendingReceived(QuicEvent):
    """
    The StopSendingReceived event is fired when the remote peer requests
    stopping data transmission on a stream.
    """

    error_code: int
    "The error code that was sent from the peer."

    stream_id: int
    "The ID of the stream that the peer requested stopping data transmission."
|
||||
|
||||
|
||||
@dataclass
class StreamReset(QuicEvent):
    """
    The StreamReset event is fired when the remote peer resets a stream.
    """

    error_code: int
    "The error code that triggered the reset."

    stream_id: int
    "The ID of the stream that was reset."
|
||||
333
.venv/lib/python3.9/site-packages/qh3/quic/logger.py
Normal file
333
.venv/lib/python3.9/site-packages/qh3/quic/logger.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from .._hazmat import RangeSet
|
||||
from ..h3.events import Headers
|
||||
from .packet import (
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
QuicStreamFrame,
|
||||
QuicTransportParameters,
|
||||
)
|
||||
|
||||
# Human-readable qlog names for each QUIC packet type.
PACKET_TYPE_NAMES = {
    QuicPacketType.INITIAL: "initial",
    QuicPacketType.HANDSHAKE: "handshake",
    QuicPacketType.ZERO_RTT: "0RTT",
    QuicPacketType.ONE_RTT: "1RTT",
    QuicPacketType.RETRY: "retry",
    QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
# qlog schema version emitted by this logger.
QLOG_VERSION = "0.3"
|
||||
|
||||
|
||||
def hexdump(data: bytes) -> str:
    """Return *data* rendered as a lowercase hexadecimal string."""
    return data.hex()
|
||||
|
||||
|
||||
class QuicLoggerTrace:
|
||||
"""
|
||||
A QUIC event trace.
|
||||
|
||||
Events are logged in the format defined by qlog.
|
||||
|
||||
See:
|
||||
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
|
||||
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
|
||||
"""
|
||||
|
||||
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
|
||||
self._odcid = odcid
|
||||
self._events: deque[dict[str, Any]] = deque()
|
||||
self._vantage_point = {
|
||||
"name": "qh3",
|
||||
"type": "client" if is_client else "server",
|
||||
}
|
||||
|
||||
# QUIC
|
||||
|
||||
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> dict:
|
||||
return {
|
||||
"ack_delay": self.encode_time(delay),
|
||||
"acked_ranges": [[x[0], x[1] - 1] for x in ranges],
|
||||
"frame_type": "ack",
|
||||
}
|
||||
|
||||
def encode_connection_close_frame(
|
||||
self, error_code: int, frame_type: int | None, reason_phrase: str
|
||||
) -> dict:
|
||||
attrs = {
|
||||
"error_code": error_code,
|
||||
"error_space": "application" if frame_type is None else "transport",
|
||||
"frame_type": "connection_close",
|
||||
"raw_error_code": error_code,
|
||||
"reason": reason_phrase,
|
||||
}
|
||||
if frame_type is not None:
|
||||
attrs["trigger_frame_type"] = frame_type
|
||||
|
||||
return attrs
|
||||
|
||||
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> dict:
|
||||
if frame_type == QuicFrameType.MAX_DATA:
|
||||
return {"frame_type": "max_data", "maximum": maximum}
|
||||
else:
|
||||
return {
|
||||
"frame_type": "max_streams",
|
||||
"maximum": maximum,
|
||||
"stream_type": (
|
||||
"unidirectional"
|
||||
if frame_type == QuicFrameType.MAX_STREAMS_UNI
|
||||
else "bidirectional"
|
||||
),
|
||||
}
|
||||
|
||||
def encode_crypto_frame(self, frame: QuicStreamFrame) -> dict:
|
||||
return {
|
||||
"frame_type": "crypto",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
}
|
||||
|
||||
def encode_data_blocked_frame(self, limit: int) -> dict:
|
||||
return {"frame_type": "data_blocked", "limit": limit}
|
||||
|
||||
def encode_datagram_frame(self, length: int) -> dict:
|
||||
return {"frame_type": "datagram", "length": length}
|
||||
|
||||
def encode_handshake_done_frame(self) -> dict:
|
||||
return {"frame_type": "handshake_done"}
|
||||
|
||||
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> dict:
|
||||
return {
|
||||
"frame_type": "max_stream_data",
|
||||
"maximum": maximum,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_new_connection_id_frame(
|
||||
self,
|
||||
connection_id: bytes,
|
||||
retire_prior_to: int,
|
||||
sequence_number: int,
|
||||
stateless_reset_token: bytes,
|
||||
) -> dict:
|
||||
return {
|
||||
"connection_id": hexdump(connection_id),
|
||||
"frame_type": "new_connection_id",
|
||||
"length": len(connection_id),
|
||||
"reset_token": hexdump(stateless_reset_token),
|
||||
"retire_prior_to": retire_prior_to,
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_new_token_frame(self, token: bytes) -> dict:
|
||||
return {
|
||||
"frame_type": "new_token",
|
||||
"length": len(token),
|
||||
"token": hexdump(token),
|
||||
}
|
||||
|
||||
def encode_padding_frame(self) -> dict:
|
||||
return {"frame_type": "padding"}
|
||||
|
||||
def encode_path_challenge_frame(self, data: bytes) -> dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_challenge"}
|
||||
|
||||
def encode_path_response_frame(self, data: bytes) -> dict:
|
||||
return {"data": hexdump(data), "frame_type": "path_response"}
|
||||
|
||||
def encode_ping_frame(self) -> dict:
|
||||
return {"frame_type": "ping"}
|
||||
|
||||
def encode_reset_stream_frame(
|
||||
self, error_code: int, final_size: int, stream_id: int
|
||||
) -> dict:
|
||||
return {
|
||||
"error_code": error_code,
|
||||
"final_size": final_size,
|
||||
"frame_type": "reset_stream",
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_retire_connection_id_frame(self, sequence_number: int) -> dict:
|
||||
return {
|
||||
"frame_type": "retire_connection_id",
|
||||
"sequence_number": sequence_number,
|
||||
}
|
||||
|
||||
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> dict:
|
||||
return {
|
||||
"frame_type": "stream_data_blocked",
|
||||
"limit": limit,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> dict:
|
||||
return {
|
||||
"frame_type": "stop_sending",
|
||||
"error_code": error_code,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> dict:
|
||||
return {
|
||||
"fin": frame.fin,
|
||||
"frame_type": "stream",
|
||||
"length": len(frame.data),
|
||||
"offset": frame.offset,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> dict:
|
||||
return {
|
||||
"frame_type": "streams_blocked",
|
||||
"limit": limit,
|
||||
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
|
||||
}
|
||||
|
||||
def encode_time(self, seconds: float) -> float:
|
||||
"""
|
||||
Convert a time to milliseconds.
|
||||
"""
|
||||
return seconds * 1000
|
||||
|
||||
def encode_transport_parameters(
|
||||
self, owner: str, parameters: QuicTransportParameters
|
||||
) -> dict[str, Any]:
|
||||
data: dict[str, Any] = {"owner": owner}
|
||||
for param_name, param_value in parameters.__dict__.items():
|
||||
if isinstance(param_value, bool):
|
||||
data[param_name] = param_value
|
||||
elif isinstance(param_value, bytes):
|
||||
data[param_name] = hexdump(param_value)
|
||||
elif isinstance(param_value, int):
|
||||
data[param_name] = param_value
|
||||
return data
|
||||
|
||||
    def packet_type(self, packet_type: QuicPacketType) -> str:
        """Return the qlog name for *packet_type*."""
        return PACKET_TYPE_NAMES[packet_type]
|
||||
|
||||
# HTTP/3
|
||||
|
||||
def encode_http3_data_frame(self, length: int, stream_id: int) -> dict:
|
||||
return {
|
||||
"frame": {"frame_type": "data"},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_headers_frame(
|
||||
self, length: int, headers: Headers, stream_id: int
|
||||
) -> dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "headers",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def encode_http3_push_promise_frame(
|
||||
self, length: int, headers: Headers, push_id: int, stream_id: int
|
||||
) -> dict:
|
||||
return {
|
||||
"frame": {
|
||||
"frame_type": "push_promise",
|
||||
"headers": self._encode_http3_headers(headers),
|
||||
"push_id": push_id,
|
||||
},
|
||||
"length": length,
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
def _encode_http3_headers(self, headers: Headers) -> list[dict]:
|
||||
return [
|
||||
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
|
||||
]
|
||||
|
||||
# CORE
|
||||
|
||||
def log_event(self, *, category: str, event: str, data: dict) -> None:
|
||||
self._events.append(
|
||||
{
|
||||
"data": data,
|
||||
"name": category + ":" + event,
|
||||
"time": self.encode_time(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return the trace as a dictionary which can be written as JSON.
|
||||
"""
|
||||
return {
|
||||
"common_fields": {
|
||||
"ODCID": hexdump(self._odcid),
|
||||
},
|
||||
"events": list(self._events),
|
||||
"vantage_point": self._vantage_point,
|
||||
}
|
||||
|
||||
|
||||
class QuicLogger:
    """
    A QUIC event logger which stores traces in memory.
    """

    def __init__(self) -> None:
        # Traces in creation order; kept until the logger itself is dropped.
        self._traces: list[QuicLoggerTrace] = []

    def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
        """Create a new trace keyed by *odcid* and register it."""
        trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
        self._traces.append(trace)
        return trace

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Finish *trace*; the in-memory logger retains it for `to_dict`."""
        # NOTE(review): `assert` is removed under `python -O`; raising
        # ValueError would make this validation unconditional.
        assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"

    def to_dict(self) -> dict[str, Any]:
        """
        Return the traces as a dictionary which can be written as JSON.
        """
        return {
            "qlog_format": "JSON",
            "qlog_version": QLOG_VERSION,
            "traces": [trace.to_dict() for trace in self._traces],
        }
|
||||
|
||||
|
||||
class QuicFileLogger(QuicLogger):
    """
    A QUIC event logger which writes one trace per file.
    """

    def __init__(self, path: str) -> None:
        # Fail fast: files are only written in `end_trace`, so a missing
        # directory would otherwise surface much later.
        if not os.path.isdir(path):
            raise ValueError(f"QUIC log output directory '{path}' does not exist")
        self.path = path
        super().__init__()

    def end_trace(self, trace: QuicLoggerTrace) -> None:
        """Write *trace* to `<path>/<ODCID>.qlog` and drop it from memory."""
        trace_dict = trace.to_dict()
        # The file is named after the Original Destination Connection ID.
        trace_path = os.path.join(
            self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
        )
        with open(trace_path, "w") as logger_fp:
            json.dump(
                {
                    "qlog_format": "JSON",
                    "qlog_version": QLOG_VERSION,
                    "traces": [trace_dict],
                },
                logger_fp,
            )
        # Unlike the in-memory logger, completed traces are not retained.
        self._traces.remove(trace)
|
||||
628
.venv/lib/python3.9/site-packages/qh3/quic/packet.py
Normal file
628
.venv/lib/python3.9/site-packages/qh3/quic/packet.py
Normal file
@@ -0,0 +1,628 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import ipaddress
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
|
||||
from .._compat import DATACLASS_KWARGS
|
||||
from .._hazmat import AeadAes128Gcm, Buffer, RangeSet
|
||||
|
||||
# First-byte flag bits shared by all packets (RFC 9000 section 17).
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20

CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
# Fixed AEAD keys/nonces protecting Retry packet integrity
# (RFC 9001 section 5.8 for v1; RFC 9369 section 3.3.3 for v2).
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
|
||||
|
||||
|
||||
class QuicErrorCode(IntEnum):
    """QUIC transport error codes (RFC 9000 section 20.1)."""

    NO_ERROR = 0x0
    INTERNAL_ERROR = 0x1
    CONNECTION_REFUSED = 0x2
    FLOW_CONTROL_ERROR = 0x3
    STREAM_LIMIT_ERROR = 0x4
    STREAM_STATE_ERROR = 0x5
    FINAL_SIZE_ERROR = 0x6
    FRAME_ENCODING_ERROR = 0x7
    TRANSPORT_PARAMETER_ERROR = 0x8
    CONNECTION_ID_LIMIT_ERROR = 0x9
    PROTOCOL_VIOLATION = 0xA
    INVALID_TOKEN = 0xB
    APPLICATION_ERROR = 0xC
    CRYPTO_BUFFER_EXCEEDED = 0xD
    KEY_UPDATE_ERROR = 0xE
    AEAD_LIMIT_REACHED = 0xF
    VERSION_NEGOTIATION_ERROR = 0x11
    # Base value for TLS-alert-derived errors (0x100 + alert code).
    CRYPTO_ERROR = 0x100
|
||||
|
||||
|
||||
class QuicPacketType(IntEnum):
    """
    Internal packet type identifiers.

    These values are NOT wire encodings; the on-the-wire long packet type
    bits differ between v1 and v2 (see the encode/decode tables below).
    """

    INITIAL = 0
    ZERO_RTT = 1
    HANDSHAKE = 2
    RETRY = 3
    VERSION_NEGOTIATION = 4
    ONE_RTT = 5
|
||||
|
||||
|
||||
# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL

# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
# Inverse mapping, used when parsing incoming headers.
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}

# QUIC version 2 shuffles the encodings to resist ossification.
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}
|
||||
|
||||
|
||||
class QuicProtocolVersion(IntEnum):
    """Wire protocol version numbers (RFC 9000, RFC 9369)."""

    NEGOTIATION = 0
    VERSION_1 = 0x00000001
    VERSION_2 = 0x6B3343CF
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicHeader:
    version: int | None
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: list[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
|
||||
|
||||
|
||||
def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Calculate the integrity tag for a RETRY packet.

    :param packet_without_tag: the serialized Retry packet, minus the tag.
    :param original_destination_cid: the client's original Destination CID.
    :param version: wire version; selects the fixed AEAD key/nonce pair.
    """
    # build Retry pseudo packet
    buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
    buf.push_uint8(len(original_destination_cid))
    buf.push_bytes(original_destination_cid)
    buf.push_bytes(packet_without_tag)
    assert buf.eof()

    # v2 uses different fixed secrets than v1 (and any unknown version
    # falls back to the v1 secrets).
    if version == QuicProtocolVersion.VERSION_2:
        aead_key = RETRY_AEAD_KEY_VERSION_2
        aead_nonce = RETRY_AEAD_NONCE_VERSION_2
    else:
        aead_key = RETRY_AEAD_KEY_VERSION_1
        aead_nonce = RETRY_AEAD_NONCE_VERSION_1

    # run AES-128-GCM
    aead = AeadAes128Gcm(aead_key, b"null!12bytes")
    # Empty plaintext: only the authentication tag is of interest.
    integrity_tag = aead.encrypt_with_nonce(aead_nonce, b"", buf.data)
    assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
    return integrity_tag
|
||||
|
||||
|
||||
def get_spin_bit(first_byte: int) -> bool:
    """Return whether the latency spin bit is set in *first_byte*."""
    return bool(first_byte & PACKET_SPIN_BIT)
|
||||
|
||||
|
||||
def is_long_header(first_byte: int) -> bool:
    """Return whether *first_byte* indicates a long header packet."""
    return bool(first_byte & PACKET_LONG_HEADER)
|
||||
|
||||
|
||||
def pretty_protocol_version(version: int) -> str:
    """
    Return a user-friendly representation of a protocol version.
    """
    known = {member.value: member.name for member in QuicProtocolVersion}
    version_name = known.get(version, "UNKNOWN")
    return "0x{:08x} ({})".format(version, version_name)
|
||||
|
||||
|
||||
def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader:
    """
    Parse a QUIC packet header from *buf*.

    :param buf: buffer positioned at the start of the packet.
    :param host_cid_length: length of connection IDs issued by this host;
        needed to delimit the destination CID of short-header packets,
        which carry no CID length on the wire.
    :raises ValueError: if the header is malformed or truncated.
    """
    packet_start = buf.tell()

    version: int | None

    # Fields only some packet types carry; default to "absent".
    integrity_tag = b""
    supported_versions = []
    token = b""

    first_byte = buf.pull_uint8()

    if is_long_header(first_byte):
        # Long Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
        version = buf.pull_uint32()

        destination_cid_length = buf.pull_uint8()
        if destination_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(
                f"Destination CID is too long ({destination_cid_length} bytes)"
            )
        destination_cid = buf.pull_bytes(destination_cid_length)

        source_cid_length = buf.pull_uint8()
        if source_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(f"Source CID is too long ({source_cid_length} bytes)")
        source_cid = buf.pull_bytes(source_cid_length)

        if version == QuicProtocolVersion.NEGOTIATION:
            # Version Negotiation Packet.
            # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
            packet_type = QuicPacketType.VERSION_NEGOTIATION
            while not buf.eof():
                supported_versions.append(buf.pull_uint32())
            packet_end = buf.tell()
        else:
            if not (first_byte & PACKET_FIXED_BIT):
                raise ValueError("Packet fixed bit is zero")

            # The long packet type bits (mask 0x30) are encoded differently
            # in v1 and v2.
            if version == QuicProtocolVersion.VERSION_2:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
                    (first_byte & 0x30) >> 4
                ]
            else:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
                    (first_byte & 0x30) >> 4
                ]

            if packet_type == QuicPacketType.INITIAL:
                token_length = buf.pull_uint_var()
                token = buf.pull_bytes(token_length)
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.ZERO_RTT:
                rest_length = buf.pull_uint_var()
            elif packet_type == QuicPacketType.HANDSHAKE:
                rest_length = buf.pull_uint_var()
            else:
                # RETRY: the token fills the remainder of the datagram,
                # minus the trailing 16-byte integrity tag.
                token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                token = buf.pull_bytes(token_length)
                integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
                rest_length = 0

            # Check remainder length.
            packet_end = buf.tell() + rest_length

            if packet_end > buf.capacity:
                raise ValueError("Packet payload is truncated")

    else:
        # Short header packets extend to the end of the datagram.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
        if not (first_byte & PACKET_FIXED_BIT):
            raise ValueError("Packet fixed bit is zero")

        version = None
        packet_type = QuicPacketType.ONE_RTT
        destination_cid = buf.pull_bytes(host_cid_length)
        source_cid = b""
        packet_end = buf.capacity

    return QuicHeader(
        version=version,
        packet_type=packet_type,
        packet_length=packet_end - packet_start,
        destination_cid=destination_cid,
        source_cid=source_cid,
        token=token,
        integrity_tag=integrity_tag,
        supported_versions=supported_versions,
    )
|
||||
|
||||
|
||||
def encode_long_header_first_byte(
    version: int, packet_type: QuicPacketType, bits: int
) -> int:
    """
    Encode the first byte of a long header packet.
    """
    # Version 2 shuffles the long packet type encodings (RFC 9369).
    table = (
        PACKET_LONG_TYPE_ENCODE_VERSION_2
        if version == QuicProtocolVersion.VERSION_2
        else PACKET_LONG_TYPE_ENCODE_VERSION_1
    )
    first_byte = PACKET_LONG_HEADER | PACKET_FIXED_BIT | bits
    return first_byte | (table[packet_type] << 4)
|
||||
|
||||
|
||||
def encode_quic_retry(
    version: int,
    source_cid: bytes,
    destination_cid: bytes,
    original_destination_cid: bytes,
    retry_token: bytes,
    unused: int = 0,
) -> bytes:
    """
    Serialize a Retry packet, including its integrity tag.

    :param unused: value for the low "unused" bits of the first byte.
    """
    # 7 = first byte + 4-byte version + the two CID length bytes.
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + len(retry_token)
        + RETRY_INTEGRITY_TAG_SIZE
    )
    buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
    buf.push_uint32(version)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    buf.push_bytes(retry_token)
    # The tag covers everything written so far plus the original DCID.
    buf.push_bytes(
        get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
    )
    assert buf.eof()
    return buf.data
|
||||
|
||||
|
||||
def encode_quic_version_negotiation(
    source_cid: bytes, destination_cid: bytes, supported_versions: list[int]
) -> bytes:
    """
    Serialize a Version Negotiation packet advertising *supported_versions*.
    """
    # 7 = first byte + 4-byte version + the two CID length bytes.
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + 4 * len(supported_versions)
    )
    # The low bits of the first byte are unused and randomized; only the
    # long-header bit must be set (RFC 9000 section 17.2.1).
    buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
    buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    for version in supported_versions:
        buf.push_uint32(version)
    return buf.data
|
||||
|
||||
|
||||
# TLS EXTENSION
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicPreferredAddress:
    """Decoded preferred_address transport parameter (RFC 9000 section 18.2)."""

    # (host, port) tuples; None when the address family is not offered.
    ipv4_address: tuple[str, int] | None
    ipv6_address: tuple[str, int] | None
    connection_id: bytes
    stateless_reset_token: bytes
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicVersionInformation:
    """Decoded version_information transport parameter (RFC 9368)."""

    chosen_version: int
    available_versions: list[int]
|
||||
|
||||
|
||||
@dataclass()
class QuicTransportParameters:
    """
    Transport parameters exchanged during the handshake
    (RFC 9000 section 18.2 plus extensions); `None` means "not sent".
    Wire IDs and value codecs are listed in the `PARAMS` table below.
    """

    original_destination_connection_id: bytes | None = None
    max_idle_timeout: int | None = None
    stateless_reset_token: bytes | None = None
    max_udp_payload_size: int | None = None
    initial_max_data: int | None = None
    initial_max_stream_data_bidi_local: int | None = None
    initial_max_stream_data_bidi_remote: int | None = None
    initial_max_stream_data_uni: int | None = None
    initial_max_streams_bidi: int | None = None
    initial_max_streams_uni: int | None = None
    ack_delay_exponent: int | None = None
    max_ack_delay: int | None = None
    disable_active_migration: bool | None = False
    preferred_address: QuicPreferredAddress | None = None
    active_connection_id_limit: int | None = None
    initial_source_connection_id: bytes | None = None
    retry_source_connection_id: bytes | None = None
    version_information: QuicVersionInformation | None = None
    max_datagram_frame_size: int | None = None
    quantum_readiness: bytes | None = None
|
||||
|
||||
|
||||
# Transport parameter registry: wire ID -> (attribute name, value codec).
# The codec type drives (de)serialization in pull/push_quic_transport_parameters.
# https://datatracker.ietf.org/doc/html/rfc9000#section-18.2
PARAMS = {
    0x00: ("original_destination_connection_id", bytes),
    0x01: ("max_idle_timeout", int),
    0x02: ("stateless_reset_token", bytes),
    0x03: ("max_udp_payload_size", int),
    0x04: ("initial_max_data", int),
    0x05: ("initial_max_stream_data_bidi_local", int),
    0x06: ("initial_max_stream_data_bidi_remote", int),
    0x07: ("initial_max_stream_data_uni", int),
    0x08: ("initial_max_streams_bidi", int),
    0x09: ("initial_max_streams_uni", int),
    0x0A: ("ack_delay_exponent", int),
    0x0B: ("max_ack_delay", int),
    0x0C: ("disable_active_migration", bool),
    0x0D: ("preferred_address", QuicPreferredAddress),
    0x0E: ("active_connection_id_limit", int),
    0x0F: ("initial_source_connection_id", bytes),
    0x10: ("retry_source_connection_id", bytes),
    # https://datatracker.ietf.org/doc/html/rfc9368#section-3
    0x11: ("version_information", QuicVersionInformation),
    # extensions
    0x0020: ("max_datagram_frame_size", int),
    0x0C37: ("quantum_readiness", bytes),
}
|
||||
|
||||
|
||||
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
    """
    Parse a preferred_address transport parameter value
    (RFC 9000 section 18.2).
    """
    ipv4_address = None
    ipv4_host = buf.pull_bytes(4)
    ipv4_port = buf.pull_uint16()
    # An all-zero host means the address family is not offered (mirrors
    # push_quic_preferred_address, which writes zeros for None).
    if ipv4_host != bytes(4):
        ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)

    ipv6_address = None
    ipv6_host = buf.pull_bytes(16)
    ipv6_port = buf.pull_uint16()
    if ipv6_host != bytes(16):
        ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)

    connection_id_length = buf.pull_uint8()
    connection_id = buf.pull_bytes(connection_id_length)
    stateless_reset_token = buf.pull_bytes(16)

    return QuicPreferredAddress(
        ipv4_address=ipv4_address,
        ipv6_address=ipv6_address,
        connection_id=connection_id,
        stateless_reset_token=stateless_reset_token,
    )
|
||||
|
||||
|
||||
def push_quic_preferred_address(
    buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
    """
    Serialize a preferred_address transport parameter value
    (RFC 9000 section 18.2).
    """
    if preferred_address.ipv4_address is not None:
        buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
        buf.push_uint16(preferred_address.ipv4_address[1])
    else:
        # 4-byte host + 2-byte port, all zero = address family not offered.
        buf.push_bytes(bytes(6))

    if preferred_address.ipv6_address is not None:
        buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
        buf.push_uint16(preferred_address.ipv6_address[1])
    else:
        # 16-byte host + 2-byte port, all zero.
        buf.push_bytes(bytes(18))

    buf.push_uint8(len(preferred_address.connection_id))
    buf.push_bytes(preferred_address.connection_id)
    buf.push_bytes(preferred_address.stateless_reset_token)
|
||||
|
||||
|
||||
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
    """
    Parse a version_information transport parameter value (RFC 9368).

    :param length: declared byte length of the value; after the chosen
        version, the remaining words are the available versions.
    :raises ValueError: if any version equals zero.
    """
    chosen_version = buf.pull_uint32()
    available_versions = [buf.pull_uint32() for _ in range(length // 4 - 1)]

    # If an endpoint receives a Chosen Version equal to zero, or any Available
    # Version equal to zero, it MUST treat it as a parsing failure.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if chosen_version == 0 or 0 in available_versions:
        raise ValueError("Version Information must not contain version 0")

    return QuicVersionInformation(
        chosen_version=chosen_version,
        available_versions=available_versions,
    )
|
||||
|
||||
|
||||
def push_quic_version_information(
    buf: Buffer, version_information: QuicVersionInformation
) -> None:
    """
    Serialize a version_information transport parameter value (RFC 9368):
    the chosen version followed by every available version.
    """
    words = [version_information.chosen_version]
    words.extend(version_information.available_versions)
    for word in words:
        buf.push_uint32(word)
|
||||
|
||||
|
||||
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
    """
    Parse a sequence of transport parameters, skipping unknown IDs.

    :raises ValueError: if a value's encoded size disagrees with its
        declared length.
    """
    params = QuicTransportParameters()
    while not buf.eof():
        param_id = buf.pull_uint_var()
        param_len = buf.pull_uint_var()
        param_start = buf.tell()
        if param_id in PARAMS:
            # parse known parameter
            param_name, param_type = PARAMS[param_id]
            if param_type is int:
                setattr(params, param_name, buf.pull_uint_var())
            elif param_type is bytes:
                setattr(params, param_name, buf.pull_bytes(param_len))
            elif param_type is QuicPreferredAddress:
                setattr(params, param_name, pull_quic_preferred_address(buf))
            elif param_type is QuicVersionInformation:
                setattr(
                    params,
                    param_name,
                    pull_quic_version_information(buf, param_len),
                )
            else:
                # boolean parameters have no body; presence means True
                setattr(params, param_name, True)
        else:
            # skip unknown parameter
            buf.pull_bytes(param_len)

        # Every branch above must consume exactly param_len bytes.
        if buf.tell() != param_start + param_len:
            raise ValueError("Transport parameter length does not match")

    return params
|
||||
|
||||
|
||||
def push_quic_transport_parameters(
    buf: Buffer, params: QuicTransportParameters
) -> None:
    """
    Serialize all set transport parameters into *buf*.

    `None` values are omitted, as are `False` booleans (boolean parameters
    are encoded by mere presence).
    """
    for param_id, (param_name, param_type) in PARAMS.items():
        param_value = getattr(params, param_name)
        if param_value is not None and param_value is not False:
            # Serialize the value into a scratch buffer first: the wire
            # format requires the value's length before the value itself.
            param_buf = Buffer(capacity=65536)
            if param_type is int:
                param_buf.push_uint_var(param_value)
            elif param_type is bytes:
                param_buf.push_bytes(param_value)
            elif param_type is QuicPreferredAddress:
                push_quic_preferred_address(param_buf, param_value)
            elif param_type is QuicVersionInformation:
                push_quic_version_information(param_buf, param_value)
            buf.push_uint_var(param_id)
            buf.push_uint_var(param_buf.tell())
            buf.push_bytes(param_buf.data)
|
||||
|
||||
|
||||
# FRAMES
|
||||
|
||||
|
||||
class QuicFrameType(IntEnum):
    """Frame type identifiers (RFC 9000 section 19; RFC 9221 for DATAGRAM)."""

    PADDING = 0x00
    PING = 0x01
    ACK = 0x02
    ACK_ECN = 0x03
    RESET_STREAM = 0x04
    STOP_SENDING = 0x05
    CRYPTO = 0x06
    NEW_TOKEN = 0x07
    # STREAM frames occupy 0x08-0x0F; the low bits are OFF/LEN/FIN flags.
    STREAM_BASE = 0x08
    MAX_DATA = 0x10
    MAX_STREAM_DATA = 0x11
    MAX_STREAMS_BIDI = 0x12
    MAX_STREAMS_UNI = 0x13
    DATA_BLOCKED = 0x14
    STREAM_DATA_BLOCKED = 0x15
    STREAMS_BLOCKED_BIDI = 0x16
    STREAMS_BLOCKED_UNI = 0x17
    NEW_CONNECTION_ID = 0x18
    RETIRE_CONNECTION_ID = 0x19
    PATH_CHALLENGE = 0x1A
    PATH_RESPONSE = 0x1B
    TRANSPORT_CLOSE = 0x1C
    APPLICATION_CLOSE = 0x1D
    HANDSHAKE_DONE = 0x1E
    DATAGRAM = 0x30
    DATAGRAM_WITH_LENGTH = 0x31
|
||||
|
||||
|
||||
# Frames that do not require the receiver to send an ACK
# (RFC 9000 section 13.2.1).
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.PADDING,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)
# Frames whose packets do not count as "in flight" (RFC 9002 section 2).
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

# Frames permitted in path-probing packets (RFC 9000 section 9.1).
PROBING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.PATH_CHALLENGE,
        QuicFrameType.PATH_RESPONSE,
        QuicFrameType.PADDING,
        QuicFrameType.NEW_CONNECTION_ID,
    ]
)
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicResetStreamFrame:
    """Decoded RESET_STREAM frame fields (RFC 9000 section 19.4)."""

    error_code: int
    final_size: int
    stream_id: int
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicStopSendingFrame:
    """Decoded STOP_SENDING frame fields (RFC 9000 section 19.5)."""

    error_code: int
    stream_id: int
|
||||
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicStreamFrame:
    """A chunk of stream (or CRYPTO) data with its offset and FIN flag."""

    data: bytes = b""
    fin: bool = False
    offset: int = 0
|
||||
|
||||
|
||||
def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]:
    """
    Parse an ACK frame body and return (acknowledged ranges, ack delay).

    The wire format counts ranges downwards from the largest acknowledged
    packet number (RFC 9000 section 19.3); gap and range-length values are
    encoded one less than their actual value.
    """
    rangeset = RangeSet()
    end = buf.pull_uint_var()  # largest acknowledged
    delay = buf.pull_uint_var()
    ack_range_count = buf.pull_uint_var()
    ack_count = buf.pull_uint_var()  # first ack range
    rangeset.add(end - ack_count, end + 1)
    end -= ack_count
    for _ in range(ack_range_count):
        # +2 folds in the two "minus one" wire encodings (gap and the
        # boundary step), mirroring push_ack_frame below.
        end -= buf.pull_uint_var() + 2
        ack_count = buf.pull_uint_var()
        rangeset.add(end - ack_count, end + 1)
        end -= ack_count
    return rangeset, delay
|
||||
|
||||
|
||||
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
    """
    Serialize an ACK frame body and return the number of ranges written.

    Ranges are emitted from the highest downwards, mirroring
    `pull_ack_frame`; `rangeset` ranges are half-open [start, end).
    """
    ranges = len(rangeset)
    index = ranges - 1
    r = rangeset[index]
    buf.push_uint_var(r[1] - 1)  # largest acknowledged
    buf.push_uint_var(delay)
    buf.push_uint_var(index)  # ack range count
    buf.push_uint_var(r[1] - 1 - r[0])  # first ack range
    start = r[0]
    while index > 0:
        index -= 1
        r = rangeset[index]
        buf.push_uint_var(start - r[1] - 1)  # gap (encoded minus one)
        buf.push_uint_var(r[1] - r[0] - 1)  # range length (encoded minus one)
        start = r[0]
    return ranges
|
||||
437
.venv/lib/python3.9/site-packages/qh3/quic/packet_builder.py
Normal file
437
.venv/lib/python3.9/site-packages/qh3/quic/packet_builder.py
Normal file
@@ -0,0 +1,437 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
from .._compat import DATACLASS_KWARGS
|
||||
from .._hazmat import Buffer, size_uint_var
|
||||
from ..tls import Epoch
|
||||
from .crypto import CryptoPair
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet import (
|
||||
NON_ACK_ELICITING_FRAME_TYPES,
|
||||
NON_IN_FLIGHT_FRAME_TYPES,
|
||||
PACKET_FIXED_BIT,
|
||||
PACKET_NUMBER_MAX_SIZE,
|
||||
QuicFrameType,
|
||||
QuicPacketType,
|
||||
encode_long_header_first_byte,
|
||||
)
|
||||
|
||||
# MinPacketSize and MaxPacketSize control the packet sizes for UDP datagrams.
|
||||
# If MinPacketSize is unset, a default value of 1280 bytes
|
||||
# will be used during the handshake.
|
||||
# If MaxPacketSize is unset, a default value of 1452 bytes will be used.
|
||||
# DPLPMTUD will automatically determine the MTU supported
|
||||
# by the link-up to the MaxPacketSize,
|
||||
# except for in the case where MinPacketSize and MaxPacketSize
|
||||
# are configured to the same value,
|
||||
# in which case path MTU discovery will be disabled.
|
||||
# Values above 65535 are invalid.
|
||||
# 20-bytes for IPv6 overhead.
|
||||
# 1280 is very conservative
|
||||
# Chrome tries 1350 at startup
|
||||
# we should do a rudimentary MTU discovery
|
||||
# Sending a PING frame 1350
|
||||
# THEN 1452
|
||||
# RFC 9000 section 14: endpoints must support UDP payloads of >= 1200 bytes.
SMALLEST_MAX_DATAGRAM_SIZE = 1200
PACKET_MAX_SIZE = 1280
# Candidate sizes for the rudimentary MTU probing described above.
MTU_PROBE_SIZES = [1350, 1452]

# Bytes reserved for the length and packet number fields when a packet
# header is written before its payload size is known.
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2


# Callback signature for per-packet delivery notifications.
QuicDeliveryHandler = Callable[..., None]
|
||||
|
||||
|
||||
class QuicDeliveryState(IntEnum):
    """Outcome of a sent packet: acknowledged by the peer, or declared lost."""

    ACKED = 0
    LOST = 1
|
||||
|
||||
@dataclass(**DATACLASS_KWARGS)
class QuicSentPacket:
    """Bookkeeping record for one sent packet."""

    epoch: Epoch
    in_flight: bool
    is_ack_eliciting: bool
    is_crypto_packet: bool
    packet_number: int
    packet_type: QuicPacketType
    # Filled in once the packet is actually transmitted.
    sent_time: float | None = None
    sent_bytes: int = 0

    # (handler, args) pairs — presumably invoked with the packet's
    # delivery outcome (see QuicDeliveryState); confirm against callers.
    delivery_handlers: list[tuple[QuicDeliveryHandler, Any]] = field(
        default_factory=list
    )
    # qlog frame dicts recorded while the packet was built.
    quic_logger_frames: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
class QuicPacketBuilderStop(Exception):
    """Raised internally to signal that packet building must stop."""
|
||||
|
||||
|
||||
class QuicPacketBuilder:
    """
    Helper for building QUIC packets.

    Frames are written into a shared datagram buffer via :meth:`start_packet`
    and :meth:`start_frame`; :meth:`flush` finalizes the current packet and
    returns the assembled datagrams together with their
    :class:`QuicSentPacket` records.  Long-header packets may be coalesced
    into a single datagram; 1-RTT packets always terminate the datagram.
    """

    __slots__ = (
        "max_flight_bytes",
        "max_total_bytes",
        "quic_logger_frames",
        "_host_cid",
        "_is_client",
        "_peer_cid",
        "_peer_token",
        "_quic_logger",
        "_spin_bit",
        "_version",
        "_datagrams",
        "_datagram_flight_bytes",
        "_datagram_init",
        "_datagram_needs_padding",
        "_packets",
        "_flight_bytes",
        "_total_bytes",
        "_header_size",
        "_packet",
        "_packet_crypto",
        "_packet_long_header",
        "_packet_number",
        "_packet_start",
        "_packet_type",
        "_buffer",
        "_buffer_capacity",
        "_flight_capacity",
    )

    def __init__(
        self,
        *,
        host_cid: bytes,
        peer_cid: bytes,
        version: int,
        is_client: bool,
        max_datagram_size: int = PACKET_MAX_SIZE,
        packet_number: int = 0,
        peer_token: bytes = b"",
        quic_logger: QuicLoggerTrace | None = None,
        spin_bit: bool = False,
    ):
        # Optional caps the owner may set between flushes: limit on
        # congestion-controlled bytes and on total bytes (e.g. the server's
        # pre-validation amplification limit).
        self.max_flight_bytes: int | None = None
        self.max_total_bytes: int | None = None
        # qlog frame list of the packet currently being built (None when no
        # packet is open).
        self.quic_logger_frames: list[dict] | None = None

        self._host_cid = host_cid
        self._is_client = is_client
        self._peer_cid = peer_cid
        self._peer_token = peer_token
        self._quic_logger = quic_logger
        self._spin_bit = spin_bit
        self._version = version

        # assembled datagrams and packets
        self._datagrams: list[bytes] = []
        self._datagram_flight_bytes = 0
        self._datagram_init = True
        self._datagram_needs_padding = False
        self._packets: list[QuicSentPacket] = []
        self._flight_bytes = 0
        self._total_bytes = 0

        # current packet
        self._header_size = 0
        self._packet: QuicSentPacket | None = None
        self._packet_crypto: CryptoPair | None = None
        self._packet_long_header = False
        self._packet_number = packet_number
        self._packet_start = 0
        self._packet_type: QuicPacketType | None = None

        # The datagram under construction; capacity may later be shrunk by
        # max_total_bytes (see start_packet).
        self._buffer = Buffer(max_datagram_size)
        self._buffer_capacity = max_datagram_size
        self._flight_capacity = max_datagram_size

    @property
    def packet_is_empty(self) -> bool:
        """
        Returns `True` if the current packet contains no frames yet
        (only its reserved header space).
        """
        assert self._packet is not None
        packet_size = self._buffer.tell() - self._packet_start
        return packet_size <= self._header_size

    @property
    def packet_number(self) -> int:
        """
        Returns the packet number for the next packet.
        """
        return self._packet_number

    @property
    def remaining_buffer_space(self) -> int:
        """
        Returns the number of payload bytes still available in the current
        datagram buffer, after reserving room for the AEAD tag.
        """
        return (
            self._buffer_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    @property
    def remaining_flight_space(self) -> int:
        """
        Like :attr:`remaining_buffer_space`, but against the (possibly
        smaller) congestion-controlled flight capacity.
        """
        return (
            self._flight_capacity
            - self._buffer.tell()
            - self._packet_crypto.aead_tag_size
        )

    def flush(self) -> tuple[list[bytes], list[QuicSentPacket]]:
        """
        Returns the assembled datagrams and their sent-packet records,
        resetting the internal lists.
        """
        if self._packet is not None:
            self._end_packet()
        self._flush_current_datagram()

        datagrams = self._datagrams
        packets = self._packets
        self._datagrams = []
        self._packets = []
        return datagrams, packets

    def start_frame(
        self,
        frame_type: int,
        capacity: int = 1,
        handler: QuicDeliveryHandler | None = None,
        # NOTE(review): mutable default ([]) — safe here because handler_args
        # is only ever stored and unpacked, never mutated.
        handler_args: Sequence[Any] = [],
    ) -> Buffer:
        """
        Starts a new frame inside the current packet.

        Writes the frame type and returns the buffer so the caller can append
        the frame payload.  Raises :class:`QuicPacketBuilderStop` when fewer
        than `capacity` bytes remain (against the flight limit too, unless
        the frame type is exempt from congestion control).
        """
        if self.remaining_buffer_space < capacity or (
            frame_type not in NON_IN_FLIGHT_FRAME_TYPES
            and self.remaining_flight_space < capacity
        ):
            raise QuicPacketBuilderStop

        self._buffer.push_uint_var(frame_type)
        # Update the packet's classification based on the frame being added.
        if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
            self._packet.is_ack_eliciting = True
        if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
            self._packet.in_flight = True
        if frame_type == QuicFrameType.CRYPTO:
            self._packet.is_crypto_packet = True
        if handler is not None:
            self._packet.delivery_handlers.append((handler, handler_args))
        return self._buffer

    def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
        """
        Starts a new packet of `packet_type`, sealed with `crypto`.

        Finishes any packet in progress, possibly starts a new datagram, and
        reserves space for the packet header (written later by _end_packet).
        Raises :class:`QuicPacketBuilderStop` when no room is left.
        """
        assert packet_type not in {
            QuicPacketType.RETRY,
            QuicPacketType.VERSION_NEGOTIATION,
        }, "Invalid packet type"

        buf = self._buffer

        # finish previous datagram
        if self._packet is not None:
            self._end_packet()

        # if there is too little space remaining, start a new datagram
        # FIXME: the limit is arbitrary!
        packet_start = buf.tell()
        if self._buffer_capacity - packet_start < 128:
            self._flush_current_datagram()
            packet_start = 0

        # initialize datagram if needed: shrink this datagram's capacity to
        # honour the total-bytes and flight-bytes budgets.
        if self._datagram_init:
            if self.max_total_bytes is not None:
                remaining_total_bytes = self.max_total_bytes - self._total_bytes
                if remaining_total_bytes < self._buffer_capacity:
                    self._buffer_capacity = remaining_total_bytes

            self._flight_capacity = self._buffer_capacity
            if self.max_flight_bytes is not None:
                remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
                if remaining_flight_bytes < self._flight_capacity:
                    self._flight_capacity = remaining_flight_bytes
            self._datagram_flight_bytes = 0
            self._datagram_init = False
            self._datagram_needs_padding = False

        # calculate header size: long header is 11 fixed bytes (first byte +
        # version + 2 CID length bytes + 2-byte length + 2-byte packet
        # number) plus both CIDs, and for INITIAL also the token; short
        # header is 3 bytes (first byte + 2-byte packet number) plus the
        # destination CID.
        if packet_type != QuicPacketType.ONE_RTT:
            header_size = 11 + len(self._peer_cid) + len(self._host_cid)
            if packet_type == QuicPacketType.INITIAL:
                token_length = len(self._peer_token)
                header_size += size_uint_var(token_length) + token_length
        else:
            header_size = 3 + len(self._peer_cid)

        # check we have enough space
        if packet_start + header_size >= self._buffer_capacity:
            raise QuicPacketBuilderStop

        # determine ack epoch
        if packet_type == QuicPacketType.INITIAL:
            epoch = Epoch.INITIAL
        elif packet_type == QuicPacketType.HANDSHAKE:
            epoch = Epoch.HANDSHAKE
        else:
            epoch = Epoch.ONE_RTT

        self._header_size = header_size
        self._packet = QuicSentPacket(
            epoch=epoch,
            in_flight=False,
            is_ack_eliciting=False,
            is_crypto_packet=False,
            packet_number=self._packet_number,
            packet_type=packet_type,
        )
        self._packet_crypto = crypto
        self._packet_start = packet_start
        self._packet_type = packet_type
        self.quic_logger_frames = self._packet.quic_logger_frames

        # Skip over the reserved header area; frames are written after it.
        buf.seek(self._packet_start + self._header_size)

    def _end_packet(self) -> None:
        """
        Ends the current packet: pads it if required, writes the header into
        the reserved space, encrypts the packet in place and records it.

        An empty packet (header only) is cancelled instead.
        """
        buf = self._buffer
        packet_size = buf.tell() - self._packet_start
        if packet_size > self._header_size:
            # padding to ensure sufficient sample size for header protection
            padding_size = (
                PACKET_NUMBER_MAX_SIZE
                - PACKET_NUMBER_SEND_SIZE
                + self._header_size
                - packet_size
            )

            # Padding for datagrams containing initial packets; see RFC 9000
            # section 14.1.
            if (
                self._is_client or self._packet.is_ack_eliciting
            ) and self._packet_type == QuicPacketType.INITIAL:
                self._datagram_needs_padding = True

            # For datagrams containing 1-RTT data, we *must* apply the padding
            # inside the packet, we cannot tack bytes onto the end of the
            # datagram.
            if (
                self._datagram_needs_padding
                and self._packet_type == QuicPacketType.ONE_RTT
            ):
                if self.remaining_flight_space > padding_size:
                    padding_size = self.remaining_flight_space
                self._datagram_needs_padding = False

            # write padding (zero bytes encode PADDING frames)
            if padding_size > 0:
                buf.push_bytes(bytes(padding_size))
                packet_size += padding_size
                self._packet.in_flight = True

                # log frame
                if self._quic_logger is not None:
                    self._packet.quic_logger_frames.append(
                        self._quic_logger.encode_padding_frame()
                    )

            # write header
            if self._packet_type != QuicPacketType.ONE_RTT:
                # Long-header "Length" field covers packet number + payload
                # + AEAD tag.
                length = (
                    packet_size
                    - self._header_size
                    + PACKET_NUMBER_SEND_SIZE
                    + self._packet_crypto.aead_tag_size
                )

                buf.seek(self._packet_start)
                buf.push_uint8(
                    encode_long_header_first_byte(
                        self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
                    )
                )
                buf.push_uint32(self._version)
                buf.push_uint8(len(self._peer_cid))
                buf.push_bytes(self._peer_cid)
                buf.push_uint8(len(self._host_cid))
                buf.push_bytes(self._host_cid)
                if self._packet_type == QuicPacketType.INITIAL:
                    buf.push_uint_var(len(self._peer_token))
                    buf.push_bytes(self._peer_token)
                # 0x4000 marks a 2-byte QUIC variable-length integer.
                buf.push_uint16(length | 0x4000)
                buf.push_uint16(self._packet_number & 0xFFFF)
            else:
                buf.seek(self._packet_start)
                buf.push_uint8(
                    PACKET_FIXED_BIT
                    | (self._spin_bit << 5)
                    | (self._packet_crypto.key_phase << 2)
                    | (PACKET_NUMBER_SEND_SIZE - 1)
                )
                buf.push_bytes(self._peer_cid)
                buf.push_uint16(self._packet_number & 0xFFFF)

            # encrypt in place
            plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
            buf.seek(self._packet_start)
            buf.push_bytes(
                self._packet_crypto.encrypt_packet(
                    plain[0 : self._header_size],
                    plain[self._header_size : packet_size],
                    self._packet_number,
                )
            )
            self._packet.sent_bytes = buf.tell() - self._packet_start
            self._packets.append(self._packet)
            if self._packet.in_flight:
                self._datagram_flight_bytes += self._packet.sent_bytes

            # Short header packets cannot be coalesced, we need a new datagram.
            if self._packet_type == QuicPacketType.ONE_RTT:
                self._flush_current_datagram()

            self._packet_number += 1
        else:
            # "cancel" the packet
            buf.seek(self._packet_start)

        self._packet = None
        self.quic_logger_frames = None

    def _flush_current_datagram(self) -> None:
        """
        Finalizes the current datagram (padding it if required), appends it
        to the datagram list and resets the buffer.
        """
        datagram_bytes = self._buffer.tell()
        if datagram_bytes:
            # Padding for datagrams containing initial packets; see RFC 9000
            # section 14.1.
            if self._datagram_needs_padding:
                extra_bytes = self._flight_capacity - self._buffer.tell()
                if extra_bytes > 0:
                    self._buffer.push_bytes(bytes(extra_bytes))
                    self._datagram_flight_bytes += extra_bytes
                    datagram_bytes += extra_bytes

            self._datagrams.append(self._buffer.data)
            self._flight_bytes += self._datagram_flight_bytes
            self._total_bytes += datagram_bytes
            self._datagram_init = True
        self._buffer.seek(0)
|
||||
503
.venv/lib/python3.9/site-packages/qh3/quic/recovery.py
Normal file
503
.venv/lib/python3.9/site-packages/qh3/quic/recovery.py
Normal file
@@ -0,0 +1,503 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
from .._hazmat import QuicPacketPacer, QuicRttMonitor, RangeSet
|
||||
from .logger import QuicLoggerTrace
|
||||
from .packet_builder import QuicDeliveryState, QuicSentPacket
|
||||
|
||||
# loss detection (RFC 9002)
K_PACKET_THRESHOLD = 3  # reordering threshold in packets
K_GRANULARITY = 0.001  # seconds
K_TIME_THRESHOLD = 9 / 8  # loss delay multiplier over max(latest, smoothed) RTT
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0

# congestion control (windows expressed in datagrams)
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2

# Cubic constants (RFC 9438)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7  # beta_cubic
K_CUBIC_MAX_IDLE_TIME = 2.0  # seconds of idleness before cwnd is reset
|
||||
|
||||
|
||||
def _cubic_root(x: float) -> float:
|
||||
if x < 0:
|
||||
return -((-x) ** (1.0 / 3.0))
|
||||
return x ** (1.0 / 3.0)
|
||||
|
||||
|
||||
class QuicPacketSpace:
    """Per-epoch bookkeeping for received packets and in-flight sent packets."""

    def __init__(self) -> None:
        # State for packets received in this space / ACK generation.
        self.discarded = False
        self.expected_packet_number = 0
        self.largest_received_packet = -1
        self.largest_received_time: float | None = None
        self.ack_at: float | None = None
        self.ack_queue = RangeSet()

        # State for packets sent in this space / loss detection.
        self.sent_packets: dict[int, QuicSentPacket] = {}
        self.ack_eliciting_in_flight = 0
        self.largest_acked_packet = 0
        self.loss_time: float | None = None
|
||||
|
||||
|
||||
class QuicCongestionControl:
    """
    Cubic congestion control (RFC 9438).

    Windows are tracked in bytes; the CUBIC formulas operate in segments of
    `max_datagram_size` bytes and results are truncated back to ints.
    """

    def __init__(self, max_datagram_size: int) -> None:
        self._max_datagram_size = max_datagram_size
        self._rtt_monitor = QuicRttMonitor()
        self._congestion_recovery_start_time = 0.0
        self._rtt = 0.02  # initial RTT estimate (20 ms)
        # Sent-time of the most recently acked packet; 0.0 until the first
        # ack (used for idle detection in on_packet_sent).
        self._last_ack = 0.0

        self.bytes_in_flight = 0
        self.congestion_window = max_datagram_size * K_INITIAL_WINDOW
        self.ssthresh: int | None = None

        # Cubic state
        self._first_slow_start = True  # never left slow start via loss
        self._starting_congestion_avoidance = False  # loss just ended an epoch
        self._K: float = 0.0  # time (s) for W_cubic to regain W_max
        self._W_max: int = self.congestion_window  # cwnd at last reduction
        self._W_est: int = 0  # Reno-friendly window estimate
        self._cwnd_epoch: int = 0  # cwnd at the start of the current epoch
        self._t_epoch: float = 0.0  # timestamp of the current epoch start

    def _W_cubic(self, t: float) -> int:
        """Evaluate the cubic window W_cubic(t) in bytes (RFC 9438 fig. 1)."""
        W_max_segments = self._W_max / self._max_datagram_size
        target_segments = K_CUBIC_C * (t - self._K) ** 3 + W_max_segments
        return int(target_segments * self._max_datagram_size)

    def _reset(self) -> None:
        """Return the controller to its initial state (used after idle)."""
        self.congestion_window = self._max_datagram_size * K_INITIAL_WINDOW
        self.ssthresh = None
        self._first_slow_start = True
        self._starting_congestion_avoidance = False
        self._K = 0.0
        self._W_max = self.congestion_window
        self._W_est = 0
        self._cwnd_epoch = 0
        self._t_epoch = 0.0

    def _start_epoch(self, now: float) -> None:
        """Begin a new congestion-avoidance epoch at time `now`."""
        self._t_epoch = now
        self._cwnd_epoch = self.congestion_window
        self._W_est = self._cwnd_epoch
        W_max_seg = self._W_max / self._max_datagram_size
        cwnd_seg = self._cwnd_epoch / self._max_datagram_size
        # K is the time it takes the cubic to climb back to W_max.
        self._K = _cubic_root((W_max_seg - cwnd_seg) / K_CUBIC_C)

    def on_packet_acked(self, packet: QuicSentPacket) -> None:
        """Grow the window for an acked in-flight packet."""
        self.bytes_in_flight -= packet.sent_bytes
        # NOTE(review): this records the packet's *sent* time, not the ack
        # arrival time; idle detection below is therefore measured between
        # packet sent-times — confirm this is intentional.
        self._last_ack = packet.sent_time

        if self.ssthresh is None or self.congestion_window < self.ssthresh:
            # slow start: grow by the number of acked bytes
            self.congestion_window += packet.sent_bytes
        else:
            # congestion avoidance
            if self._first_slow_start and not self._starting_congestion_avoidance:
                # exiting slow start without a loss (HyStart triggered)
                self._first_slow_start = False
                self._W_max = self.congestion_window
                self._start_epoch(packet.sent_time)

            if self._starting_congestion_avoidance:
                # entering congestion avoidance after a loss
                self._starting_congestion_avoidance = False
                self._first_slow_start = False
                self._start_epoch(packet.sent_time)

            # TCP-friendly estimate (Reno-like linear growth)
            self._W_est = int(
                self._W_est
                + self._max_datagram_size * (packet.sent_bytes / self.congestion_window)
            )

            # Evaluate the cubic one RTT into the future, per RFC 9438.
            t = packet.sent_time - self._t_epoch
            W_cubic = self._W_cubic(t + self._rtt)

            # clamp target to [cwnd, 1.5 * cwnd]
            if W_cubic < self.congestion_window:
                target = self.congestion_window
            elif W_cubic > int(1.5 * self.congestion_window):
                target = int(1.5 * self.congestion_window)
            else:
                target = W_cubic

            if self._W_cubic(t) < self._W_est:
                # Reno-friendly region
                self.congestion_window = self._W_est
            else:
                # concave / convex region: approach target by one segment's
                # worth per cwnd of acked data
                self.congestion_window = int(
                    self.congestion_window
                    + (target - self.congestion_window)
                    * (self._max_datagram_size / self.congestion_window)
                )

    def on_packet_sent(self, packet: QuicSentPacket) -> None:
        """Account for a sent in-flight packet; reset cwnd after a long idle."""
        self.bytes_in_flight += packet.sent_bytes
        # reset cwnd after prolonged idle
        if self._last_ack > 0.0:
            elapsed_idle = packet.sent_time - self._last_ack
            if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
                self._reset()

    def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
        """Remove discarded packets from flight without a congestion event."""
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes

    def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
        """Handle lost packets; may start a new congestion recovery period."""
        lost_largest_time = 0.0
        for packet in packets:
            self.bytes_in_flight -= packet.sent_bytes
            # packets are assumed in ascending order, so this ends up as the
            # sent time of the last (largest) lost packet
            lost_largest_time = packet.sent_time

        # start a new congestion event if packet was sent after the
        # start of the previous congestion recovery period.
        if lost_largest_time > self._congestion_recovery_start_time:
            self._congestion_recovery_start_time = now

            # fast convergence: if W_max is decreasing, reduce it further
            if self.congestion_window < self._W_max:
                self._W_max = int(
                    self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
                )
            else:
                self._W_max = self.congestion_window

            # multiplicative decrease, floored at the minimum window
            self.congestion_window = max(
                int(self.congestion_window * K_CUBIC_LOSS_REDUCTION_FACTOR),
                self._max_datagram_size * K_MINIMUM_WINDOW,
            )
            self.ssthresh = self.congestion_window
            self._starting_congestion_avoidance = True

    def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
        """Record an RTT sample; exit slow start if RTT keeps increasing."""
        self._rtt = latest_rtt
        # check whether we should exit slow start
        if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
            latest_rtt, now
        ):
            self.ssthresh = self.congestion_window
|
||||
|
||||
|
||||
class QuicPacketRecovery:
    """
    Packet loss and congestion controller.

    Implements RFC 9002-style loss detection (packet and time thresholds,
    PTO) on top of :class:`QuicCongestionControl` and a packet pacer.
    """

    def __init__(
        self,
        initial_rtt: float,
        peer_completed_address_validation: bool,
        send_probe: Callable[[], None],
        max_datagram_size: int = 1280,
        logger: logging.LoggerAdapter | None = None,
        quic_logger: QuicLoggerTrace | None = None,
    ) -> None:
        self.max_ack_delay = 0.025
        self.peer_completed_address_validation = peer_completed_address_validation
        # Packet spaces registered by the connection (Initial/Handshake/1-RTT).
        self.spaces: list[QuicPacketSpace] = []

        # callbacks
        self._logger = logger
        self._quic_logger = quic_logger
        self._send_probe = send_probe

        # loss detection
        self._pto_count = 0  # consecutive PTO expirations (exponential backoff)
        self._rtt_initial = initial_rtt
        self._rtt_initialized = False
        self._rtt_latest = 0.0
        self._rtt_min = math.inf
        self._rtt_smoothed = 0.0
        self._rtt_variance = 0.0
        self._time_of_last_sent_ack_eliciting_packet = 0.0

        # congestion control
        self._cc = QuicCongestionControl(max_datagram_size)
        self._pacer = QuicPacketPacer(max_datagram_size)

    @property
    def bytes_in_flight(self) -> int:
        return self._cc.bytes_in_flight

    @property
    def congestion_window(self) -> int:
        return self._cc.congestion_window

    def discard_space(self, space: QuicPacketSpace) -> None:
        """Drop a packet space, expiring its packets without a loss event."""
        assert space in self.spaces

        self._cc.on_packets_expired(
            filter(lambda x: x.in_flight, space.sent_packets.values())
        )
        space.sent_packets.clear()

        space.ack_at = None
        space.ack_eliciting_in_flight = 0
        space.loss_time = None

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated()

    def get_loss_detection_time(self) -> float | None:
        """
        Return the time at which the loss-detection timer fires, or None
        when nothing is outstanding and the peer is validated.
        """
        # loss timer
        loss_space = self._get_loss_space()
        if loss_space is not None:
            return loss_space.loss_time

        # packet timer (PTO), with exponential backoff
        if (
            not self.peer_completed_address_validation
            or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
        ):
            timeout = self.get_probe_timeout() * (2**self._pto_count)
            return self._time_of_last_sent_ack_eliciting_packet + timeout

        return None

    def get_probe_timeout(self) -> float:
        """Return the base probe timeout (PTO) from the current RTT estimate."""
        if not self._rtt_initialized:
            return 2 * self._rtt_initial
        return (
            self._rtt_smoothed
            + max(4 * self._rtt_variance, K_GRANULARITY)
            + self.max_ack_delay
        )

    def on_ack_received(
        self,
        space: QuicPacketSpace,
        ack_rangeset: RangeSet,
        ack_delay: float,
        now: float,
    ) -> None:
        """
        Update metrics as the result of an ACK being received.

        Removes newly acked packets, updates the RTT estimators and the
        congestion controller, then runs loss detection.
        """
        is_ack_eliciting = False
        # bounds() yields a half-open range; the largest acked number is the
        # upper bound minus one.
        largest_acked = ack_rangeset.bounds()[1] - 1
        largest_newly_acked = None
        largest_sent_time = None

        if largest_acked > space.largest_acked_packet:
            space.largest_acked_packet = largest_acked

        for packet_number in sorted(space.sent_packets.keys()):
            if packet_number > largest_acked:
                break
            if packet_number in ack_rangeset:
                # remove packet and update counters
                packet = space.sent_packets.pop(packet_number)
                if packet.is_ack_eliciting:
                    is_ack_eliciting = True
                    space.ack_eliciting_in_flight -= 1
                if packet.in_flight:
                    self._cc.on_packet_acked(packet)
                largest_newly_acked = packet_number
                largest_sent_time = packet.sent_time

                # trigger callbacks
                for handler, args in packet.delivery_handlers:
                    handler(QuicDeliveryState.ACKED, *args)

        # nothing to do if there are no newly acked packets
        if largest_newly_acked is None:
            return

        # Only take an RTT sample when the largest acked packet is newly
        # acked and at least one newly acked packet was ack-eliciting.
        if largest_acked == largest_newly_acked and is_ack_eliciting:
            latest_rtt = now - largest_sent_time
            log_rtt = True

            # limit ACK delay to max_ack_delay
            ack_delay = min(ack_delay, self.max_ack_delay)

            # update RTT estimate, which cannot be < 1 ms
            self._rtt_latest = max(latest_rtt, 0.001)
            if self._rtt_latest < self._rtt_min:
                self._rtt_min = self._rtt_latest
            # only deduct the peer's ack delay if it doesn't push us below
            # the minimum observed RTT
            if self._rtt_latest > self._rtt_min + ack_delay:
                self._rtt_latest -= ack_delay

            # standard EWMA smoothing (RFC 9002 section 5.3)
            if not self._rtt_initialized:
                self._rtt_initialized = True
                self._rtt_variance = latest_rtt / 2
                self._rtt_smoothed = latest_rtt
            else:
                self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
                    self._rtt_min - self._rtt_latest
                )
                self._rtt_smoothed = (
                    7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
                )

            # inform congestion controller
            self._cc.on_rtt_measurement(latest_rtt, now=now)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )

        else:
            log_rtt = False

        self._detect_loss(space, now=now)

        # reset PTO count
        self._pto_count = 0

        if self._quic_logger is not None:
            self._log_metrics_updated(log_rtt=log_rtt)

    def on_loss_detection_timeout(self, now: float) -> None:
        """Handle expiry of the loss-detection timer: declare losses or probe."""
        loss_space = self._get_loss_space()
        if loss_space is not None:
            self._detect_loss(loss_space, now=now)
        else:
            self._pto_count += 1
            self.reschedule_data(now=now)

    def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
        """Track a freshly sent packet in its space and in congestion control."""
        space.sent_packets[packet.packet_number] = packet

        if packet.is_ack_eliciting:
            space.ack_eliciting_in_flight += 1
        if packet.in_flight:
            if packet.is_ack_eliciting:
                self._time_of_last_sent_ack_eliciting_packet = packet.sent_time

            # add packet to bytes in flight
            self._cc.on_packet_sent(packet)

            if self._quic_logger is not None:
                self._log_metrics_updated()

    def reschedule_data(self, now: float) -> None:
        """
        Schedule some data for retransmission.
        """
        # if there is any outstanding CRYPTO, retransmit it
        crypto_scheduled = False
        for space in self.spaces:
            packets = tuple(
                filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
            )
            if packets:
                self._on_packets_lost(packets, space=space, now=now)
                crypto_scheduled = True
        if crypto_scheduled and self._logger is not None:
            self._logger.debug("Scheduled CRYPTO data for retransmission")

        # ensure an ACK-elliciting packet is sent
        self._send_probe()

    def _detect_loss(self, space: QuicPacketSpace, now: float) -> None:
        """
        Check whether any packets should be declared lost.

        A packet is lost when it was sent K_PACKET_THRESHOLD packets before
        the largest acked one, or longer than `loss_delay` ago; otherwise
        its deadline feeds the space's loss timer.
        """
        loss_delay = K_TIME_THRESHOLD * (
            max(self._rtt_latest, self._rtt_smoothed)
            if self._rtt_initialized
            else self._rtt_initial
        )
        packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
        time_threshold = now - loss_delay

        lost_packets = []
        space.loss_time = None
        # sent_packets preserves insertion (ascending packet-number) order,
        # so we can stop at the first number above the largest acked.
        for packet_number, packet in space.sent_packets.items():
            if packet_number > space.largest_acked_packet:
                break

            if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
                lost_packets.append(packet)
            else:
                packet_loss_time = packet.sent_time + loss_delay
                if space.loss_time is None or space.loss_time > packet_loss_time:
                    space.loss_time = packet_loss_time

        self._on_packets_lost(lost_packets, space=space, now=now)

    def _get_loss_space(self) -> QuicPacketSpace | None:
        """Return the space with the earliest pending loss time, if any."""
        loss_space = None
        for space in self.spaces:
            if space.loss_time is not None and (
                loss_space is None or space.loss_time < loss_space.loss_time
            ):
                loss_space = space
        return loss_space

    def _log_metrics_updated(self, log_rtt: bool = False) -> None:
        """Emit a qlog "metrics_updated" event (requires self._quic_logger)."""
        data: dict[str, Any] = {
            "bytes_in_flight": self._cc.bytes_in_flight,
            "cwnd": self._cc.congestion_window,
        }
        if self._cc.ssthresh is not None:
            data["ssthresh"] = self._cc.ssthresh

        if log_rtt:
            data.update(
                {
                    "latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
                    "min_rtt": self._quic_logger.encode_time(self._rtt_min),
                    "smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
                    "rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
                }
            )

        self._quic_logger.log_event(
            category="recovery", event="metrics_updated", data=data
        )

    def _on_packets_lost(
        self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float
    ) -> None:
        """Remove lost packets from the space, notify handlers and the CC."""
        lost_packets_cc = []
        for packet in packets:
            del space.sent_packets[packet.packet_number]

            if packet.in_flight:
                lost_packets_cc.append(packet)

            if packet.is_ack_eliciting:
                space.ack_eliciting_in_flight -= 1

            if self._quic_logger is not None:
                self._quic_logger.log_event(
                    category="recovery",
                    event="packet_lost",
                    data={
                        "type": self._quic_logger.packet_type(packet.packet_type),
                        "packet_number": packet.packet_number,
                    },
                )
                self._log_metrics_updated()

            # trigger callbacks
            for handler, args in packet.delivery_handlers:
                handler(QuicDeliveryState.LOST, *args)

        # inform congestion controller
        if lost_packets_cc:
            self._cc.on_packets_lost(lost_packets_cc, now=now)
            self._pacer.update_rate(
                congestion_window=self._cc.congestion_window,
                smoothed_rtt=self._rtt_smoothed,
            )
            if self._quic_logger is not None:
                self._log_metrics_updated()
|
||||
39
.venv/lib/python3.9/site-packages/qh3/quic/retry.py
Normal file
39
.venv/lib/python3.9/site-packages/qh3/quic/retry.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
|
||||
from .._hazmat import Buffer, Rsa
|
||||
from ..tls import pull_opaque, push_opaque
|
||||
from .connection import NetworkAddress
|
||||
|
||||
|
||||
def encode_address(addr: NetworkAddress) -> bytes:
    """Serialize an address as packed IP bytes followed by a big-endian port."""
    host, port = addr[0], addr[1]
    return ipaddress.ip_address(host).packed + port.to_bytes(2, "big")
|
||||
|
||||
|
||||
class QuicRetryTokenHandler:
    """
    Creates and validates Retry tokens.

    Tokens are RSA-encrypted blobs binding the client's network address to
    the connection IDs used during the Retry exchange; the serialization
    order of the three opaque fields is the token's wire format.
    """

    def __init__(self) -> None:
        # Per-handler ephemeral 2048-bit key; tokens are only valid for the
        # lifetime of this instance.
        self._key = Rsa(key_size=2048)

    def create_token(
        self,
        addr: NetworkAddress,
        original_destination_connection_id: bytes,
        retry_source_connection_id: bytes,
    ) -> bytes:
        """Return an encrypted token binding `addr` to both connection IDs."""
        buf = Buffer(capacity=512)
        # Field order matters: address, then ODCID, then retry SCID.
        push_opaque(buf, 1, encode_address(addr))
        push_opaque(buf, 1, original_destination_connection_id)
        push_opaque(buf, 1, retry_source_connection_id)
        return self._key.encrypt(buf.data)

    def validate_token(self, addr: NetworkAddress, token: bytes) -> tuple[bytes, bytes]:
        """
        Decrypt `token` and verify it was issued to `addr`.

        Returns (original_destination_connection_id,
        retry_source_connection_id).  Raises ValueError when the token has
        the wrong length (256 bytes, matching the 2048-bit key) or was
        issued to a different address.
        """
        if not token or len(token) != 256:
            raise ValueError("Ciphertext length must be equal to key size.")
        buf = Buffer(data=self._key.decrypt(token))
        encoded_addr = pull_opaque(buf, 1)
        original_destination_connection_id = pull_opaque(buf, 1)
        retry_source_connection_id = pull_opaque(buf, 1)
        if encoded_addr != encode_address(addr):
            raise ValueError("Remote address does not match.")
        return original_destination_connection_id, retry_source_connection_id
|
||||
380
.venv/lib/python3.9/site-packages/qh3/quic/stream.py
Normal file
380
.venv/lib/python3.9/site-packages/qh3/quic/stream.py
Normal file
@@ -0,0 +1,380 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .._hazmat import RangeSet
|
||||
from . import events
|
||||
from .packet import (
|
||||
QuicErrorCode,
|
||||
QuicResetStreamFrame,
|
||||
QuicStopSendingFrame,
|
||||
QuicStreamFrame,
|
||||
)
|
||||
from .packet_builder import QuicDeliveryState
|
||||
|
||||
|
||||
class FinalSizeError(Exception):
    """Raised when received STREAM data violates the stream's final size."""
|
||||
|
||||
|
||||
class StreamFinishedError(Exception):
    """Raised when an operation is attempted on an already finished stream."""
|
||||
|
||||
|
||||
class QuicStreamReceiver:
    """
    The receive part of a QUIC stream.

    Reassembles out-of-order STREAM frames into in-order data and enforces
    the final-size invariants.

    It finishes:
    - immediately for a send-only stream
    - upon reception of a STREAM_RESET frame
    - upon reception of a data frame with the FIN bit set
    """

    __slots__ = (
        "highest_offset",
        "is_finished",
        "stop_pending",
        "_buffer",
        "_buffer_start",
        "_final_size",
        "_ranges",
        "_stream_id",
        "_stop_error_code",
    )

    def __init__(self, stream_id: int | None, readable: bool) -> None:
        # NOTE(review): `readable` is accepted but never used in this class —
        # confirm whether it is kept only for symmetry with
        # QuicStreamSender(writable=...), which does honour its flag.
        self.highest_offset = 0  # the highest offset ever seen
        self.is_finished = False
        self.stop_pending = False  # True when a STOP_SENDING frame should be sent

        self._buffer = bytearray()  # staging area for not-yet-delivered bytes
        self._buffer_start = 0  # the offset for the start of the buffer
        self._final_size: int | None = None  # fixed by FIN or RESET_STREAM
        self._ranges = RangeSet()  # byte ranges received but not yet delivered
        self._stream_id = stream_id
        self._stop_error_code: int | None = None

    def get_stop_frame(self) -> QuicStopSendingFrame:
        """Build the pending STOP_SENDING frame and clear the pending flag."""
        self.stop_pending = False
        return QuicStopSendingFrame(
            error_code=self._stop_error_code,
            stream_id=self._stream_id,
        )

    def starting_offset(self) -> int:
        """Return the stream offset of the first byte still held in the buffer."""
        return self._buffer_start

    def handle_frame(self, frame: QuicStreamFrame) -> events.StreamDataReceived | None:
        """
        Handle a frame of received data.

        Returns a StreamDataReceived event when new in-order data (or the end
        of the stream) becomes deliverable, otherwise None.

        Raises FinalSizeError when the frame conflicts with an established
        final size (data beyond FIN, or a FIN at a different offset).

        May mutate ``frame.data`` / ``frame.offset`` in place when trimming
        duplicated leading bytes.
        """
        pos = frame.offset - self._buffer_start  # position relative to our buffer
        count = len(frame.data)
        frame_end = frame.offset + count

        # we should receive no more data beyond FIN!
        if self._final_size is not None:
            if frame_end > self._final_size:
                raise FinalSizeError("Data received beyond final size")
            elif frame.fin and frame_end != self._final_size:
                raise FinalSizeError("Cannot change final size")
        if frame.fin:
            self._final_size = frame_end
        if frame_end > self.highest_offset:
            self.highest_offset = frame_end

        # fast path: new in-order chunk arriving while nothing is buffered —
        # deliver the frame's data directly without copying into the buffer
        if pos == 0 and count and not self._buffer:
            self._buffer_start += count
            if frame.fin:
                # all data up to the FIN has been received, we're done receiving
                self.is_finished = True
            return events.StreamDataReceived(
                data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
            )

        # discard duplicate data that overlaps bytes already delivered
        if pos < 0:
            frame.data = frame.data[-pos:]
            frame.offset -= pos
            pos = 0
            count = len(frame.data)

        # mark the received range (frame may be empty after trimming)
        if frame_end > frame.offset:
            self._ranges.add(frame.offset, frame_end)

        # add new data, padding any hole between the buffer end and this frame
        gap = pos - len(self._buffer)
        if gap > 0:
            self._buffer += bytearray(gap)
        self._buffer[pos : pos + count] = frame.data

        # return contiguous data from the front of the buffer, if any
        data = self._pull_data()
        end_stream = self._buffer_start == self._final_size
        if end_stream:
            # all data up to the FIN has been received, we're done receiving
            self.is_finished = True
        if data or end_stream:
            return events.StreamDataReceived(
                data=data, end_stream=end_stream, stream_id=self._stream_id
            )
        else:
            return None

    def handle_reset(
        self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
    ) -> events.StreamReset | None:
        """
        Handle an abrupt termination of the receiving part of the QUIC stream.

        Raises FinalSizeError when the reset's final size disagrees with a
        previously established one.
        """
        if self._final_size is not None and final_size != self._final_size:
            raise FinalSizeError("Cannot change final size")

        # we are done receiving
        self._final_size = final_size
        self.is_finished = True
        return events.StreamReset(error_code=error_code, stream_id=self._stream_id)

    def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a STOP_SENDING is ACK'd.

        Anything other than an ACK re-arms `stop_pending` so the frame gets
        retransmitted.
        """
        if delivery != QuicDeliveryState.ACKED:
            self.stop_pending = True

    def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
        """
        Request the peer stop sending data on the QUIC stream.
        """
        self._stop_error_code = error_code
        self.stop_pending = True

    def _pull_data(self) -> bytes:
        """
        Remove and return contiguous data from the front of the buffer.

        Returns b"" when the earliest received range does not start at
        `_buffer_start`, i.e. a gap remains before the next data.
        """
        try:
            has_data_to_read = self._ranges[0][0] == self._buffer_start
        except IndexError:
            has_data_to_read = False
        if not has_data_to_read:
            return b""

        r = self._ranges.shift()
        pos = r[1] - r[0]
        data = bytes(self._buffer[:pos])
        del self._buffer[:pos]
        self._buffer_start = r[1]
        return data
|
||||
class QuicStreamSender:
    """
    The send part of a QUIC stream.

    Buffers outgoing data until it is ACK'd, and tracks which ranges still
    need (re)transmission.

    It finishes:
    - immediately for a receive-only stream
    - upon acknowledgement of a STREAM_RESET frame
    - upon acknowledgement of a data frame with the FIN bit set
    """

    __slots__ = (
        "buffer_is_empty",
        "highest_offset",
        "is_finished",
        "reset_pending",
        "_acked",
        "_buffer",
        "_buffer_fin",
        "_buffer_start",
        "_buffer_stop",
        "_pending",
        "_pending_eof",
        "_reset_error_code",
        "_stream_id",
        "send_buffer_empty",
    )

    def __init__(self, stream_id: int | None, writable: bool) -> None:
        # NOTE(review): the `send_buffer_empty` slot is never initialised here
        # and is only ever written (never read) in on_data_delivery() —
        # confirm whether `buffer_is_empty` was intended there.
        self.buffer_is_empty = True
        self.highest_offset = 0  # the highest offset ever sent
        self.is_finished = not writable
        self.reset_pending = False  # True when a RESET_STREAM frame should be sent

        self._acked = RangeSet()  # ACK'd ranges not yet trimmed from the buffer
        self._buffer = bytearray()  # outgoing data not yet fully ACK'd
        self._buffer_fin: int | None = None  # FIN offset, set by write(end_stream=True)
        self._buffer_start = 0  # the offset for the start of the buffer
        self._buffer_stop = 0  # the offset for the stop of the buffer
        self._pending = RangeSet()  # ranges waiting to be (re)transmitted
        self._pending_eof = False  # True when the FIN bit still needs sending
        self._reset_error_code: int | None = None
        self._stream_id = stream_id

    @property
    def next_offset(self) -> int:
        """
        The offset for the next frame to send.

        This is used to determine the space needed for the frame's `offset` field.
        """
        try:
            return self._pending[0][0]
        except IndexError:
            return self._buffer_stop

    def get_frame(
        self, max_size: int, max_offset: int | None = None
    ) -> QuicStreamFrame | None:
        """
        Get a frame of data to send, or None when nothing is sendable.

        `max_size` caps the payload length; `max_offset` applies flow control
        (no byte at or beyond it may be sent).
        """
        # get the first pending data range
        try:
            r = self._pending[0]
        except IndexError:
            if self._pending_eof:
                # FIN only — an empty frame carrying just the FIN bit
                self._pending_eof = False
                return QuicStreamFrame(fin=True, offset=self._buffer_fin)

            self.buffer_is_empty = True
            return None

        # apply size limit and flow control
        start = r[0]
        stop = min(r[1], start + max_size)
        if max_offset is not None and stop > max_offset:
            stop = max_offset
        if stop <= start:
            return None

        # create frame; slice offsets are relative to the buffer start
        frame = QuicStreamFrame(
            data=bytes(
                self._buffer[start - self._buffer_start : stop - self._buffer_start]
            ),
            offset=start,
        )
        self._pending.subtract(start, stop)

        # track the highest offset ever sent
        if stop > self.highest_offset:
            self.highest_offset = stop

        # if the frame reaches the end of the written data and EOF was
        # requested, set the FIN bit
        if self._buffer_fin == stop:
            frame.fin = True
            self._pending_eof = False

        return frame

    def get_reset_frame(self) -> QuicResetStreamFrame:
        """Build the pending RESET_STREAM frame and clear the pending flag."""
        self.reset_pending = False
        return QuicResetStreamFrame(
            error_code=self._reset_error_code,
            final_size=self.highest_offset,
            stream_id=self._stream_id,
        )

    def on_data_delivery(
        self, delivery: QuicDeliveryState, start: int, stop: int
    ) -> None:
        """
        Callback when sent data is ACK'd (or deemed lost).

        On ACK, contiguous acknowledged data is trimmed from the front of the
        buffer; otherwise the range is re-queued for retransmission.
        """
        self.buffer_is_empty = False
        if delivery == QuicDeliveryState.ACKED:
            if stop > start:
                self._acked.add(start, stop)
                first_range = self._acked[0]
                if first_range[0] == self._buffer_start:
                    size = first_range[1] - first_range[0]
                    self._acked.shift()
                    self._buffer_start += size
                    del self._buffer[:size]

            if self._buffer_start == self._buffer_fin:
                # all data up to the FIN has been ACK'd, we're done sending
                self.is_finished = True
        else:
            # not ACK'd: schedule the range (and FIN, if applicable) again
            if stop > start:
                self._pending.add(start, stop)
            if stop == self._buffer_fin:
                # NOTE(review): `send_buffer_empty` is never read anywhere in
                # this class — looks like `buffer_is_empty` was intended; verify.
                self.send_buffer_empty = False
                self._pending_eof = True

    def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
        """
        Callback when a reset is ACK'd (or deemed lost).
        """
        if delivery == QuicDeliveryState.ACKED:
            # the reset has been ACK'd, we're done sending
            self.is_finished = True
        else:
            # retransmit the RESET_STREAM frame
            self.reset_pending = True

    def reset(self, error_code: int) -> None:
        """
        Abruptly terminate the sending part of the QUIC stream.

        Once this method has been called, any further calls to it
        will have no effect.
        """
        if self._reset_error_code is None:
            self._reset_error_code = error_code
            self.reset_pending = True

        # Prevent any more data from being sent or re-sent.
        self.buffer_is_empty = True

    def write(self, data: bytes, end_stream: bool = False) -> None:
        """
        Write some data bytes to the QUIC stream.

        Asserts that neither a FIN nor a reset has already been issued.
        """
        assert self._buffer_fin is None, "cannot call write() after FIN"
        assert self._reset_error_code is None, "cannot call write() after reset()"
        size = len(data)

        if size:
            self.buffer_is_empty = False
            self._pending.add(self._buffer_stop, self._buffer_stop + size)
            self._buffer += data
            self._buffer_stop += size
        if end_stream:
            self.buffer_is_empty = False
            self._buffer_fin = self._buffer_stop
            self._pending_eof = True
||||
|
||||
class QuicStream:
|
||||
__slots__ = (
|
||||
"is_blocked",
|
||||
"max_stream_data_local",
|
||||
"max_stream_data_local_sent",
|
||||
"max_stream_data_remote",
|
||||
"receiver",
|
||||
"sender",
|
||||
"stream_id",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: int | None = None,
|
||||
max_stream_data_local: int = 0,
|
||||
max_stream_data_remote: int = 0,
|
||||
readable: bool = True,
|
||||
writable: bool = True,
|
||||
) -> None:
|
||||
self.is_blocked = False
|
||||
self.max_stream_data_local = max_stream_data_local
|
||||
self.max_stream_data_local_sent = max_stream_data_local
|
||||
self.max_stream_data_remote = max_stream_data_remote
|
||||
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
|
||||
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
|
||||
self.stream_id = stream_id
|
||||
|
||||
@property
|
||||
def is_finished(self) -> bool:
|
||||
return self.receiver.is_finished and self.sender.is_finished
|
||||
2278
.venv/lib/python3.9/site-packages/qh3/tls.py
Normal file
2278
.venv/lib/python3.9/site-packages/qh3/tls.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user