fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,191 @@
from __future__ import annotations
from dataclasses import dataclass, field
from os import PathLike
from re import split
from typing import TYPE_CHECKING, TextIO
if TYPE_CHECKING:
from .._hazmat import Certificate as X509Certificate
from .._hazmat import DsaPrivateKey, EcPrivateKey, Ed25519PrivateKey, RsaPrivateKey
from ..tls import (
CipherSuite,
SessionTicket,
load_pem_private_key,
load_pem_x509_certificates,
)
from .logger import QuicLogger
from .packet import QuicProtocolVersion
from .packet_builder import PACKET_MAX_SIZE
@dataclass
class QuicConfiguration:
"""
A QUIC configuration.
"""
alpn_protocols: list[str] | None = None
"""
A list of supported ALPN protocols.
"""
connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
"""
idle_timeout: float = 60.0
"""
The idle timeout in seconds.
The connection is terminated if nothing is received for the given duration.
"""
is_client: bool = True
"""
Whether this is the client side of the QUIC connection.
"""
max_data: int = 1048576
"""
Connection-wide flow control limit.
"""
max_datagram_size: int = PACKET_MAX_SIZE
"""
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
"""
probe_datagram_size: bool = True
"""
Enable path MTU discovery. Client-only.
"""
max_stream_data: int = 1048576
"""
Per-stream flow control limit.
"""
quic_logger: QuicLogger | None = None
"""
The :class:`~qh3.quic.logger.QuicLogger` instance to log events to.
"""
secrets_log_file: TextIO = None
"""
A file-like object in which to log traffic secrets.
This is useful to analyze traffic captures with Wireshark.
"""
server_name: str | None = None
"""
The server name to send during the TLS handshake the Server Name Indication.
.. note:: This is only used by clients.
"""
session_ticket: SessionTicket | None = None
"""
The TLS session ticket which should be used for session resumption.
"""
hostname_checks_common_name: bool = False
assert_fingerprint: str | None = None
verify_hostname: bool = True
cadata: bytes | None = None
cafile: str | None = None
capath: str | None = None
certificate: X509Certificate | None = None
certificate_chain: list[X509Certificate] = field(default_factory=list)
cipher_suites: list[CipherSuite] | None = None
initial_rtt: float = 0.1
max_datagram_frame_size: int | None = None
original_version: int | None = None
private_key: (
EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None
) = None
quantum_readiness_test: bool = False
supported_versions: list[int] = field(
default_factory=lambda: [
QuicProtocolVersion.VERSION_1,
QuicProtocolVersion.VERSION_2,
]
)
verify_mode: int | None = None
def load_cert_chain(
self,
certfile: str | bytes | PathLike,
keyfile: str | bytes | PathLike | None = None,
password: bytes | str | None = None,
) -> None:
"""
Load a private key and the corresponding certificate.
"""
if isinstance(certfile, str):
certfile = certfile.encode()
elif isinstance(certfile, PathLike):
certfile = str(certfile).encode()
if keyfile is not None:
if isinstance(keyfile, str):
keyfile = keyfile.encode()
elif isinstance(keyfile, PathLike):
keyfile = str(keyfile).encode()
# we either have the certificate or a file path in certfile/keyfile.
if b"-----BEGIN" not in certfile:
with open(certfile, "rb") as fp:
certfile = fp.read()
if keyfile is not None:
with open(keyfile, "rb") as fp:
keyfile = fp.read()
is_crlf = b"-----\r\n" in certfile
boundary = (
b"-----BEGIN PRIVATE KEY-----\n"
if not is_crlf
else b"-----BEGIN PRIVATE KEY-----\r\n"
)
chunks = split(b"\n" + boundary, certfile)
certificates = load_pem_x509_certificates(chunks[0])
if len(chunks) == 2:
private_key = boundary + chunks[1]
self.private_key = load_pem_private_key(private_key)
self.certificate = certificates[0]
self.certificate_chain = certificates[1:]
if keyfile is not None:
self.private_key = load_pem_private_key(
keyfile,
password=(
password.encode("utf8") if isinstance(password, str) else password
),
)
def load_verify_locations(
self,
cafile: str | None = None,
capath: str | None = None,
cadata: bytes | None = None,
) -> None:
"""
Load a set of "certification authority" (CA) certificates used to
validate other peers' certificates.
"""
self.cafile = cafile
self.capath = capath
self.cadata = cadata

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,284 @@
from __future__ import annotations
import binascii
from typing import Callable
from .._hazmat import (
AeadAes128Gcm,
AeadAes256Gcm,
AeadChaCha20Poly1305,
CryptoError,
decode_packet_number,
)
from .._hazmat import QUICHeaderProtection as HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import (
QuicProtocolVersion,
is_long_header,
)
CIPHER_SUITES = {
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
}
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
SAMPLE_SIZE = 16
Callback = Callable[[str], None]
def NoCallback(trigger: str) -> None:
pass
class KeyUnavailableError(CryptoError):
pass
def derive_key_iv_hp(
*, cipher_suite: CipherSuite, secret: bytes, version: int
) -> tuple[bytes, bytes, bytes]:
algorithm = cipher_suite_hash(cipher_suite)
if cipher_suite in [
CipherSuite.AES_256_GCM_SHA384,
CipherSuite.CHACHA20_POLY1305_SHA256,
]:
key_size = 32
else:
key_size = 16
if version == QuicProtocolVersion.VERSION_2:
return (
hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size),
)
else:
return (
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
)
class CryptoContext:
__slots__ = (
"aead",
"cipher_suite",
"hp",
"key_phase",
"secret",
"version",
"_setup_cb",
"_teardown_cb",
)
def __init__(
self,
key_phase: int = 0,
setup_cb: Callback = NoCallback,
teardown_cb: Callback = NoCallback,
) -> None:
self.aead: AeadChaCha20Poly1305 | AeadAes128Gcm | AeadAes256Gcm | None = None
self.cipher_suite: CipherSuite | None = None
self.hp: HeaderProtection | None = None
self.key_phase = key_phase
self.secret: bytes | None = None
self.version: int | None = None
self._setup_cb = setup_cb
self._teardown_cb = teardown_cb
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> tuple[bytes, bytes, int, bool]:
if self.aead is None:
raise KeyUnavailableError("Decryption key is not available")
# header protection
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
first_byte = plain_header[0]
# packet number
pn_length = (first_byte & 0x03) + 1
packet_number = decode_packet_number(
packet_number, pn_length * 8, expected_packet_number
)
# detect key phase change
crypto = self
if not is_long_header(first_byte):
key_phase = (first_byte & 4) >> 2
if key_phase != self.key_phase:
crypto = next_key_phase(self)
# payload protection
payload = crypto.aead.decrypt(
packet_number, packet[len(plain_header) :], plain_header
)
return plain_header, payload, packet_number, crypto != self
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
assert self.is_valid(), "Encryption key is not available"
# payload protection
protected_payload = self.aead.encrypt(
packet_number, plain_payload, plain_header
)
# header protection
return self.hp.apply(plain_header, protected_payload)
def is_valid(self) -> bool:
return self.aead is not None
def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
key, iv, hp = derive_key_iv_hp(
cipher_suite=cipher_suite,
secret=secret,
version=version,
)
if aead_cipher_name == b"chacha20-poly1305":
self.aead = AeadChaCha20Poly1305(key, iv)
elif aead_cipher_name == b"aes-256-gcm":
self.aead = AeadAes256Gcm(key, iv)
elif aead_cipher_name == b"aes-128-gcm":
self.aead = AeadAes128Gcm(key, iv)
else:
raise CryptoError(f"Invalid cipher name: {aead_cipher_name.decode()}")
self.cipher_suite = cipher_suite
self.hp = HeaderProtection(hp_cipher_name.decode(), hp)
self.secret = secret
self.version = version
# trigger callback
self._setup_cb("tls")
def teardown(self) -> None:
self.aead = None
self.cipher_suite = None
self.hp = None
self.secret = None
# trigger callback
self._teardown_cb("tls")
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
self.aead = crypto.aead
self.key_phase = crypto.key_phase
self.secret = crypto.secret
# trigger callback
self._setup_cb(trigger)
def next_key_phase(self: CryptoContext) -> CryptoContext:
algorithm = cipher_suite_hash(self.cipher_suite)
crypto = CryptoContext(key_phase=int(not self.key_phase))
crypto.setup(
cipher_suite=self.cipher_suite,
secret=hkdf_expand_label(
algorithm, self.secret, b"quic ku", b"", int(algorithm / 8)
),
version=self.version,
)
return crypto
class CryptoPair:
__slots__ = (
"aead_tag_size",
"recv",
"send",
"_update_key_requested",
)
def __init__(
self,
recv_setup_cb: Callback = NoCallback,
recv_teardown_cb: Callback = NoCallback,
send_setup_cb: Callback = NoCallback,
send_teardown_cb: Callback = NoCallback,
) -> None:
self.aead_tag_size = 16
self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
self._update_key_requested = False
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> tuple[bytes, bytes, int]:
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
packet, encrypted_offset, expected_packet_number
)
if update_key:
self._update_key("remote_update")
return plain_header, payload, packet_number
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
if self._update_key_requested:
self._update_key("local_update")
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
if is_client:
recv_label, send_label = b"server in", b"client in"
else:
recv_label, send_label = b"client in", b"server in"
if version == QuicProtocolVersion.VERSION_2:
initial_salt = INITIAL_SALT_VERSION_2
else:
initial_salt = INITIAL_SALT_VERSION_1
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
digest_size = int(algorithm / 8)
initial_secret = hkdf_extract(algorithm, initial_salt, cid)
self.recv.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, recv_label, b"", digest_size
),
version=version,
)
self.send.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, send_label, b"", digest_size
),
version=version,
)
def teardown(self) -> None:
self.recv.teardown()
self.send.teardown()
def update_key(self) -> None:
self._update_key_requested = True
@property
def key_phase(self) -> int:
if self._update_key_requested:
return int(not self.recv.key_phase)
else:
return self.recv.key_phase
def _update_key(self, trigger: str) -> None:
apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
self._update_key_requested = False

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
from dataclasses import dataclass
class QuicEvent:
"""
Base class for QUIC events.
"""
pass
@dataclass
class ConnectionIdIssued(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionIdRetired(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionTerminated(QuicEvent):
"""
The ConnectionTerminated event is fired when the QUIC connection is terminated.
"""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
@dataclass
class DatagramFrameReceived(QuicEvent):
"""
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
"""
data: bytes
"The data which was received."
@dataclass
class HandshakeCompleted(QuicEvent):
"""
The HandshakeCompleted event is fired when the TLS handshake completes.
"""
alpn_protocol: str | None
"The protocol which was negotiated using ALPN, or `None`."
early_data_accepted: bool
"Whether early (0-RTT) data was accepted by the remote peer."
session_resumed: bool
"Whether a TLS session was resumed."
@dataclass
class PingAcknowledged(QuicEvent):
"""
The PingAcknowledged event is fired when a PING frame is acknowledged.
"""
uid: int
"The unique ID of the PING."
@dataclass
class ProtocolNegotiated(QuicEvent):
"""
The ProtocolNegotiated event is fired when ALPN negotiation completes.
"""
alpn_protocol: str | None
"The protocol which was negotiated using ALPN, or `None`."
@dataclass
class StreamDataReceived(QuicEvent):
"""
The StreamDataReceived event is fired whenever data is received on a
stream.
"""
data: bytes
"The data which was received."
end_stream: bool
"Whether the STREAM frame had the FIN bit set."
stream_id: int
"The ID of the stream the data was received for."
@dataclass
class StopSendingReceived(QuicEvent):
"""
The StopSendingReceived event is fired when the remote peer requests
stopping data transmission on a stream.
"""
error_code: int
"The error code that was sent from the peer."
stream_id: int
"The ID of the stream that the peer requested stopping data transmission."
@dataclass
class StreamReset(QuicEvent):
"""
The StreamReset event is fired when the remote peer resets a stream.
"""
error_code: int
"The error code that triggered the reset."
stream_id: int
"The ID of the stream that was reset."

View File

@@ -0,0 +1,333 @@
from __future__ import annotations
import binascii
import json
import os
import time
from collections import deque
from typing import Any
from .._hazmat import RangeSet
from ..h3.events import Headers
from .packet import (
QuicFrameType,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
)
PACKET_TYPE_NAMES = {
QuicPacketType.INITIAL: "initial",
QuicPacketType.HANDSHAKE: "handshake",
QuicPacketType.ZERO_RTT: "0RTT",
QuicPacketType.ONE_RTT: "1RTT",
QuicPacketType.RETRY: "retry",
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
QLOG_VERSION = "0.3"
def hexdump(data: bytes) -> str:
return binascii.hexlify(data).decode("ascii")
class QuicLoggerTrace:
"""
A QUIC event trace.
Events are logged in the format defined by qlog.
See:
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
"""
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
self._odcid = odcid
self._events: deque[dict[str, Any]] = deque()
self._vantage_point = {
"name": "qh3",
"type": "client" if is_client else "server",
}
# QUIC
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> dict:
return {
"ack_delay": self.encode_time(delay),
"acked_ranges": [[x[0], x[1] - 1] for x in ranges],
"frame_type": "ack",
}
def encode_connection_close_frame(
self, error_code: int, frame_type: int | None, reason_phrase: str
) -> dict:
attrs = {
"error_code": error_code,
"error_space": "application" if frame_type is None else "transport",
"frame_type": "connection_close",
"raw_error_code": error_code,
"reason": reason_phrase,
}
if frame_type is not None:
attrs["trigger_frame_type"] = frame_type
return attrs
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> dict:
if frame_type == QuicFrameType.MAX_DATA:
return {"frame_type": "max_data", "maximum": maximum}
else:
return {
"frame_type": "max_streams",
"maximum": maximum,
"stream_type": (
"unidirectional"
if frame_type == QuicFrameType.MAX_STREAMS_UNI
else "bidirectional"
),
}
def encode_crypto_frame(self, frame: QuicStreamFrame) -> dict:
return {
"frame_type": "crypto",
"length": len(frame.data),
"offset": frame.offset,
}
def encode_data_blocked_frame(self, limit: int) -> dict:
return {"frame_type": "data_blocked", "limit": limit}
def encode_datagram_frame(self, length: int) -> dict:
return {"frame_type": "datagram", "length": length}
def encode_handshake_done_frame(self) -> dict:
return {"frame_type": "handshake_done"}
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> dict:
return {
"frame_type": "max_stream_data",
"maximum": maximum,
"stream_id": stream_id,
}
def encode_new_connection_id_frame(
self,
connection_id: bytes,
retire_prior_to: int,
sequence_number: int,
stateless_reset_token: bytes,
) -> dict:
return {
"connection_id": hexdump(connection_id),
"frame_type": "new_connection_id",
"length": len(connection_id),
"reset_token": hexdump(stateless_reset_token),
"retire_prior_to": retire_prior_to,
"sequence_number": sequence_number,
}
def encode_new_token_frame(self, token: bytes) -> dict:
return {
"frame_type": "new_token",
"length": len(token),
"token": hexdump(token),
}
def encode_padding_frame(self) -> dict:
return {"frame_type": "padding"}
def encode_path_challenge_frame(self, data: bytes) -> dict:
return {"data": hexdump(data), "frame_type": "path_challenge"}
def encode_path_response_frame(self, data: bytes) -> dict:
return {"data": hexdump(data), "frame_type": "path_response"}
def encode_ping_frame(self) -> dict:
return {"frame_type": "ping"}
def encode_reset_stream_frame(
self, error_code: int, final_size: int, stream_id: int
) -> dict:
return {
"error_code": error_code,
"final_size": final_size,
"frame_type": "reset_stream",
"stream_id": stream_id,
}
def encode_retire_connection_id_frame(self, sequence_number: int) -> dict:
return {
"frame_type": "retire_connection_id",
"sequence_number": sequence_number,
}
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> dict:
return {
"frame_type": "stream_data_blocked",
"limit": limit,
"stream_id": stream_id,
}
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> dict:
return {
"frame_type": "stop_sending",
"error_code": error_code,
"stream_id": stream_id,
}
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> dict:
return {
"fin": frame.fin,
"frame_type": "stream",
"length": len(frame.data),
"offset": frame.offset,
"stream_id": stream_id,
}
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> dict:
return {
"frame_type": "streams_blocked",
"limit": limit,
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
}
def encode_time(self, seconds: float) -> float:
"""
Convert a time to milliseconds.
"""
return seconds * 1000
def encode_transport_parameters(
self, owner: str, parameters: QuicTransportParameters
) -> dict[str, Any]:
data: dict[str, Any] = {"owner": owner}
for param_name, param_value in parameters.__dict__.items():
if isinstance(param_value, bool):
data[param_name] = param_value
elif isinstance(param_value, bytes):
data[param_name] = hexdump(param_value)
elif isinstance(param_value, int):
data[param_name] = param_value
return data
def packet_type(self, packet_type: QuicPacketType) -> str:
return PACKET_TYPE_NAMES[packet_type]
# HTTP/3
def encode_http3_data_frame(self, length: int, stream_id: int) -> dict:
return {
"frame": {"frame_type": "data"},
"length": length,
"stream_id": stream_id,
}
def encode_http3_headers_frame(
self, length: int, headers: Headers, stream_id: int
) -> dict:
return {
"frame": {
"frame_type": "headers",
"headers": self._encode_http3_headers(headers),
},
"length": length,
"stream_id": stream_id,
}
def encode_http3_push_promise_frame(
self, length: int, headers: Headers, push_id: int, stream_id: int
) -> dict:
return {
"frame": {
"frame_type": "push_promise",
"headers": self._encode_http3_headers(headers),
"push_id": push_id,
},
"length": length,
"stream_id": stream_id,
}
def _encode_http3_headers(self, headers: Headers) -> list[dict]:
return [
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
]
# CORE
def log_event(self, *, category: str, event: str, data: dict) -> None:
self._events.append(
{
"data": data,
"name": category + ":" + event,
"time": self.encode_time(time.time()),
}
)
def to_dict(self) -> dict[str, Any]:
"""
Return the trace as a dictionary which can be written as JSON.
"""
return {
"common_fields": {
"ODCID": hexdump(self._odcid),
},
"events": list(self._events),
"vantage_point": self._vantage_point,
}
class QuicLogger:
"""
A QUIC event logger which stores traces in memory.
"""
def __init__(self) -> None:
self._traces: list[QuicLoggerTrace] = []
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
self._traces.append(trace)
return trace
def end_trace(self, trace: QuicLoggerTrace) -> None:
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
def to_dict(self) -> dict[str, Any]:
"""
Return the traces as a dictionary which can be written as JSON.
"""
return {
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace.to_dict() for trace in self._traces],
}
class QuicFileLogger(QuicLogger):
"""
A QUIC event logger which writes one trace per file.
"""
def __init__(self, path: str) -> None:
if not os.path.isdir(path):
raise ValueError(f"QUIC log output directory '{path}' does not exist")
self.path = path
super().__init__()
def end_trace(self, trace: QuicLoggerTrace) -> None:
trace_dict = trace.to_dict()
trace_path = os.path.join(
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
)
with open(trace_path, "w") as logger_fp:
json.dump(
{
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace_dict],
},
logger_fp,
)
self._traces.remove(trace)

View File

@@ -0,0 +1,628 @@
from __future__ import annotations
import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import IntEnum
from .._compat import DATACLASS_KWARGS
from .._hazmat import AeadAes128Gcm, Buffer, RangeSet
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
class QuicErrorCode(IntEnum):
NO_ERROR = 0x0
INTERNAL_ERROR = 0x1
CONNECTION_REFUSED = 0x2
FLOW_CONTROL_ERROR = 0x3
STREAM_LIMIT_ERROR = 0x4
STREAM_STATE_ERROR = 0x5
FINAL_SIZE_ERROR = 0x6
FRAME_ENCODING_ERROR = 0x7
TRANSPORT_PARAMETER_ERROR = 0x8
CONNECTION_ID_LIMIT_ERROR = 0x9
PROTOCOL_VIOLATION = 0xA
INVALID_TOKEN = 0xB
APPLICATION_ERROR = 0xC
CRYPTO_BUFFER_EXCEEDED = 0xD
KEY_UPDATE_ERROR = 0xE
AEAD_LIMIT_REACHED = 0xF
VERSION_NEGOTIATION_ERROR = 0x11
CRYPTO_ERROR = 0x100
class QuicPacketType(IntEnum):
INITIAL = 0
ZERO_RTT = 1
HANDSHAKE = 2
RETRY = 3
VERSION_NEGOTIATION = 4
ONE_RTT = 5
# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
QuicPacketType.INITIAL: 0,
QuicPacketType.ZERO_RTT: 1,
QuicPacketType.HANDSHAKE: 2,
QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}
# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
QuicPacketType.INITIAL: 1,
QuicPacketType.ZERO_RTT: 2,
QuicPacketType.HANDSHAKE: 3,
QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}
class QuicProtocolVersion(IntEnum):
NEGOTIATION = 0
VERSION_1 = 0x00000001
VERSION_2 = 0x6B3343CF
@dataclass(**DATACLASS_KWARGS)
class QuicHeader:
version: int | None
"The protocol version. Only present in long header packets."
packet_type: QuicPacketType
"The type of the packet."
packet_length: int
"The total length of the packet, in bytes."
destination_cid: bytes
"The destination connection ID."
source_cid: bytes
"The destination connection ID."
token: bytes
"The address verification token. Only present in `INITIAL` and `RETRY` packets."
integrity_tag: bytes
"The retry integrity tag. Only present in `RETRY` packets."
supported_versions: list[int]
"Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
def get_retry_integrity_tag(
packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
"""
Calculate the integrity tag for a RETRY packet.
"""
# build Retry pseudo packet
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
buf.push_uint8(len(original_destination_cid))
buf.push_bytes(original_destination_cid)
buf.push_bytes(packet_without_tag)
assert buf.eof()
if version == QuicProtocolVersion.VERSION_2:
aead_key = RETRY_AEAD_KEY_VERSION_2
aead_nonce = RETRY_AEAD_NONCE_VERSION_2
else:
aead_key = RETRY_AEAD_KEY_VERSION_1
aead_nonce = RETRY_AEAD_NONCE_VERSION_1
# run AES-128-GCM
aead = AeadAes128Gcm(aead_key, b"null!12bytes")
integrity_tag = aead.encrypt_with_nonce(aead_nonce, b"", buf.data)
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
return integrity_tag
def get_spin_bit(first_byte: int) -> bool:
if first_byte & PACKET_SPIN_BIT:
return True
return False
def is_long_header(first_byte: int) -> bool:
if first_byte & PACKET_LONG_HEADER:
return True
return False
def pretty_protocol_version(version: int) -> str:
"""
Return a user-friendly representation of a protocol version.
"""
try:
version_name = QuicProtocolVersion(version).name
except ValueError:
version_name = "UNKNOWN"
return f"0x{version:08x} ({version_name})"
def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader:
packet_start = buf.tell()
version: int | None
integrity_tag = b""
supported_versions = []
token = b""
first_byte = buf.pull_uint8()
if is_long_header(first_byte):
# Long Header Packets.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
version = buf.pull_uint32()
destination_cid_length = buf.pull_uint8()
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(
f"Destination CID is too long ({destination_cid_length} bytes)"
)
destination_cid = buf.pull_bytes(destination_cid_length)
source_cid_length = buf.pull_uint8()
if source_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(f"Source CID is too long ({source_cid_length} bytes)")
source_cid = buf.pull_bytes(source_cid_length)
if version == QuicProtocolVersion.NEGOTIATION:
# Version Negotiation Packet.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
packet_type = QuicPacketType.VERSION_NEGOTIATION
while not buf.eof():
supported_versions.append(buf.pull_uint32())
packet_end = buf.tell()
else:
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
if version == QuicProtocolVersion.VERSION_2:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
(first_byte & 0x30) >> 4
]
else:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
(first_byte & 0x30) >> 4
]
if packet_type == QuicPacketType.INITIAL:
token_length = buf.pull_uint_var()
token = buf.pull_bytes(token_length)
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.ZERO_RTT:
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.HANDSHAKE:
rest_length = buf.pull_uint_var()
else:
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
token = buf.pull_bytes(token_length)
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
rest_length = 0
# Check remainder length.
packet_end = buf.tell() + rest_length
if packet_end > buf.capacity:
raise ValueError("Packet payload is truncated")
else:
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
version = None
packet_type = QuicPacketType.ONE_RTT
destination_cid = buf.pull_bytes(host_cid_length)
source_cid = b""
packet_end = buf.capacity
return QuicHeader(
version=version,
packet_type=packet_type,
packet_length=packet_end - packet_start,
destination_cid=destination_cid,
source_cid=source_cid,
token=token,
integrity_tag=integrity_tag,
supported_versions=supported_versions,
)
def encode_long_header_first_byte(
version: int, packet_type: QuicPacketType, bits: int
) -> int:
"""
Encode the first byte of a long header packet.
"""
if version == QuicProtocolVersion.VERSION_2:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
else:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
return (
PACKET_LONG_HEADER
| PACKET_FIXED_BIT
| long_type_encode[packet_type] << 4
| bits
)
def encode_quic_retry(
version: int,
source_cid: bytes,
destination_cid: bytes,
original_destination_cid: bytes,
retry_token: bytes,
unused: int = 0,
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ len(retry_token)
+ RETRY_INTEGRITY_TAG_SIZE
)
buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
buf.push_uint32(version)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
buf.push_bytes(retry_token)
buf.push_bytes(
get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
)
assert buf.eof()
return buf.data
def encode_quic_version_negotiation(
source_cid: bytes, destination_cid: bytes, supported_versions: list[int]
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ 4 * len(supported_versions)
)
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
for version in supported_versions:
buf.push_uint32(version)
return buf.data
# TLS EXTENSION
@dataclass(**DATACLASS_KWARGS)
class QuicPreferredAddress:
ipv4_address: tuple[str, int] | None
ipv6_address: tuple[str, int] | None
connection_id: bytes
stateless_reset_token: bytes
@dataclass(**DATACLASS_KWARGS)
class QuicVersionInformation:
chosen_version: int
available_versions: list[int]
@dataclass()
class QuicTransportParameters:
original_destination_connection_id: bytes | None = None
max_idle_timeout: int | None = None
stateless_reset_token: bytes | None = None
max_udp_payload_size: int | None = None
initial_max_data: int | None = None
initial_max_stream_data_bidi_local: int | None = None
initial_max_stream_data_bidi_remote: int | None = None
initial_max_stream_data_uni: int | None = None
initial_max_streams_bidi: int | None = None
initial_max_streams_uni: int | None = None
ack_delay_exponent: int | None = None
max_ack_delay: int | None = None
disable_active_migration: bool | None = False
preferred_address: QuicPreferredAddress | None = None
active_connection_id_limit: int | None = None
initial_source_connection_id: bytes | None = None
retry_source_connection_id: bytes | None = None
version_information: QuicVersionInformation | None = None
max_datagram_frame_size: int | None = None
quantum_readiness: bytes | None = None
PARAMS = {
0x00: ("original_destination_connection_id", bytes),
0x01: ("max_idle_timeout", int),
0x02: ("stateless_reset_token", bytes),
0x03: ("max_udp_payload_size", int),
0x04: ("initial_max_data", int),
0x05: ("initial_max_stream_data_bidi_local", int),
0x06: ("initial_max_stream_data_bidi_remote", int),
0x07: ("initial_max_stream_data_uni", int),
0x08: ("initial_max_streams_bidi", int),
0x09: ("initial_max_streams_uni", int),
0x0A: ("ack_delay_exponent", int),
0x0B: ("max_ack_delay", int),
0x0C: ("disable_active_migration", bool),
0x0D: ("preferred_address", QuicPreferredAddress),
0x0E: ("active_connection_id_limit", int),
0x0F: ("initial_source_connection_id", bytes),
0x10: ("retry_source_connection_id", bytes),
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
0x11: ("version_information", QuicVersionInformation),
# extensions
0x0020: ("max_datagram_frame_size", int),
0x0C37: ("quantum_readiness", bytes),
}
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
ipv4_address = None
ipv4_host = buf.pull_bytes(4)
ipv4_port = buf.pull_uint16()
if ipv4_host != bytes(4):
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
ipv6_address = None
ipv6_host = buf.pull_bytes(16)
ipv6_port = buf.pull_uint16()
if ipv6_host != bytes(16):
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
connection_id_length = buf.pull_uint8()
connection_id = buf.pull_bytes(connection_id_length)
stateless_reset_token = buf.pull_bytes(16)
return QuicPreferredAddress(
ipv4_address=ipv4_address,
ipv6_address=ipv6_address,
connection_id=connection_id,
stateless_reset_token=stateless_reset_token,
)
def push_quic_preferred_address(
buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
if preferred_address.ipv4_address is not None:
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
buf.push_uint16(preferred_address.ipv4_address[1])
else:
buf.push_bytes(bytes(6))
if preferred_address.ipv6_address is not None:
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
buf.push_uint16(preferred_address.ipv6_address[1])
else:
buf.push_bytes(bytes(18))
buf.push_uint8(len(preferred_address.connection_id))
buf.push_bytes(preferred_address.connection_id)
buf.push_bytes(preferred_address.stateless_reset_token)
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
chosen_version = buf.pull_uint32()
available_versions = []
for i in range(length // 4 - 1):
available_versions.append(buf.pull_uint32())
# If an endpoint receives a Chosen Version equal to zero, or any Available Version
# equal to zero, it MUST treat it as a parsing failure.
#
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
if chosen_version == 0 or 0 in available_versions:
raise ValueError("Version Information must not contain version 0")
return QuicVersionInformation(
chosen_version=chosen_version,
available_versions=available_versions,
)
def push_quic_version_information(
buf: Buffer, version_information: QuicVersionInformation
) -> None:
buf.push_uint32(version_information.chosen_version)
for version in version_information.available_versions:
buf.push_uint32(version)
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
params = QuicTransportParameters()
while not buf.eof():
param_id = buf.pull_uint_var()
param_len = buf.pull_uint_var()
param_start = buf.tell()
if param_id in PARAMS:
# parse known parameter
param_name, param_type = PARAMS[param_id]
if param_type is int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type is bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type is QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
elif param_type is QuicVersionInformation:
setattr(
params,
param_name,
pull_quic_version_information(buf, param_len),
)
else:
setattr(params, param_name, True)
else:
# skip unknown parameter
buf.pull_bytes(param_len)
if buf.tell() != param_start + param_len:
raise ValueError("Transport parameter length does not match")
return params
def push_quic_transport_parameters(
buf: Buffer, params: QuicTransportParameters
) -> None:
for param_id, (param_name, param_type) in PARAMS.items():
param_value = getattr(params, param_name)
if param_value is not None and param_value is not False:
param_buf = Buffer(capacity=65536)
if param_type is int:
param_buf.push_uint_var(param_value)
elif param_type is bytes:
param_buf.push_bytes(param_value)
elif param_type is QuicPreferredAddress:
push_quic_preferred_address(param_buf, param_value)
elif param_type is QuicVersionInformation:
push_quic_version_information(param_buf, param_value)
buf.push_uint_var(param_id)
buf.push_uint_var(param_buf.tell())
buf.push_bytes(param_buf.data)
# FRAMES
class QuicFrameType(IntEnum):
PADDING = 0x00
PING = 0x01
ACK = 0x02
ACK_ECN = 0x03
RESET_STREAM = 0x04
STOP_SENDING = 0x05
CRYPTO = 0x06
NEW_TOKEN = 0x07
STREAM_BASE = 0x08
MAX_DATA = 0x10
MAX_STREAM_DATA = 0x11
MAX_STREAMS_BIDI = 0x12
MAX_STREAMS_UNI = 0x13
DATA_BLOCKED = 0x14
STREAM_DATA_BLOCKED = 0x15
STREAMS_BLOCKED_BIDI = 0x16
STREAMS_BLOCKED_UNI = 0x17
NEW_CONNECTION_ID = 0x18
RETIRE_CONNECTION_ID = 0x19
PATH_CHALLENGE = 0x1A
PATH_RESPONSE = 0x1B
TRANSPORT_CLOSE = 0x1C
APPLICATION_CLOSE = 0x1D
HANDSHAKE_DONE = 0x1E
DATAGRAM = 0x30
DATAGRAM_WITH_LENGTH = 0x31
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.PADDING,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
PROBING_FRAME_TYPES = frozenset(
[
QuicFrameType.PATH_CHALLENGE,
QuicFrameType.PATH_RESPONSE,
QuicFrameType.PADDING,
QuicFrameType.NEW_CONNECTION_ID,
]
)
@dataclass(**DATACLASS_KWARGS)
class QuicResetStreamFrame:
error_code: int
final_size: int
stream_id: int
@dataclass(**DATACLASS_KWARGS)
class QuicStopSendingFrame:
error_code: int
stream_id: int
@dataclass(**DATACLASS_KWARGS)
class QuicStreamFrame:
data: bytes = b""
fin: bool = False
offset: int = 0
def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]:
rangeset = RangeSet()
end = buf.pull_uint_var() # largest acknowledged
delay = buf.pull_uint_var()
ack_range_count = buf.pull_uint_var()
ack_count = buf.pull_uint_var() # first ack range
rangeset.add(end - ack_count, end + 1)
end -= ack_count
for _ in range(ack_range_count):
end -= buf.pull_uint_var() + 2
ack_count = buf.pull_uint_var()
rangeset.add(end - ack_count, end + 1)
end -= ack_count
return rangeset, delay
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
ranges = len(rangeset)
index = ranges - 1
r = rangeset[index]
buf.push_uint_var(r[1] - 1)
buf.push_uint_var(delay)
buf.push_uint_var(index)
buf.push_uint_var(r[1] - 1 - r[0])
start = r[0]
while index > 0:
index -= 1
r = rangeset[index]
buf.push_uint_var(start - r[1] - 1)
buf.push_uint_var(r[1] - r[0] - 1)
start = r[0]
return ranges

View File

@@ -0,0 +1,437 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Callable, Sequence
from .._compat import DATACLASS_KWARGS
from .._hazmat import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
NON_ACK_ELICITING_FRAME_TYPES,
NON_IN_FLIGHT_FRAME_TYPES,
PACKET_FIXED_BIT,
PACKET_NUMBER_MAX_SIZE,
QuicFrameType,
QuicPacketType,
encode_long_header_first_byte,
)
# MinPacketSize and MaxPacketSize control the packet sizes for UDP datagrams.
# If MinPacketSize is unset, a default value of 1280 bytes
# will be used during the handshake.
# If MaxPacketSize is unset, a default value of 1452 bytes will be used.
# DPLPMTUD will automatically determine the MTU supported
# by the link-up to the MaxPacketSize,
# except for in the case where MinPacketSize and MaxPacketSize
# are configured to the same value,
# in which case path MTU discovery will be disabled.
# Values above 65355 are invalid.
# 20-bytes for IPv6 overhead.
# 1280 is very conservative
# Chrome tries 1350 at startup
# we should do a rudimentary MTU discovery
# Sending a PING frame 1350
# THEN 1452
SMALLEST_MAX_DATAGRAM_SIZE = 1200
PACKET_MAX_SIZE = 1280
MTU_PROBE_SIZES = [1350, 1452]
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2
QuicDeliveryHandler = Callable[..., None]
class QuicDeliveryState(IntEnum):
ACKED = 0
LOST = 1
@dataclass(**DATACLASS_KWARGS)
class QuicSentPacket:
epoch: Epoch
in_flight: bool
is_ack_eliciting: bool
is_crypto_packet: bool
packet_number: int
packet_type: QuicPacketType
sent_time: float | None = None
sent_bytes: int = 0
delivery_handlers: list[tuple[QuicDeliveryHandler, Any]] = field(
default_factory=list
)
quic_logger_frames: list[dict] = field(default_factory=list)
class QuicPacketBuilderStop(Exception):
pass
class QuicPacketBuilder:
"""
Helper for building QUIC packets.
"""
__slots__ = (
"max_flight_bytes",
"max_total_bytes",
"quic_logger_frames",
"_host_cid",
"_is_client",
"_peer_cid",
"_peer_token",
"_quic_logger",
"_spin_bit",
"_version",
"_datagrams",
"_datagram_flight_bytes",
"_datagram_init",
"_datagram_needs_padding",
"_packets",
"_flight_bytes",
"_total_bytes",
"_header_size",
"_packet",
"_packet_crypto",
"_packet_long_header",
"_packet_number",
"_packet_start",
"_packet_type",
"_buffer",
"_buffer_capacity",
"_flight_capacity",
)
def __init__(
self,
*,
host_cid: bytes,
peer_cid: bytes,
version: int,
is_client: bool,
max_datagram_size: int = PACKET_MAX_SIZE,
packet_number: int = 0,
peer_token: bytes = b"",
quic_logger: QuicLoggerTrace | None = None,
spin_bit: bool = False,
):
self.max_flight_bytes: int | None = None
self.max_total_bytes: int | None = None
self.quic_logger_frames: list[dict] | None = None
self._host_cid = host_cid
self._is_client = is_client
self._peer_cid = peer_cid
self._peer_token = peer_token
self._quic_logger = quic_logger
self._spin_bit = spin_bit
self._version = version
# assembled datagrams and packets
self._datagrams: list[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._datagram_needs_padding = False
self._packets: list[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
# current packet
self._header_size = 0
self._packet: QuicSentPacket | None = None
self._packet_crypto: CryptoPair | None = None
self._packet_long_header = False
self._packet_number = packet_number
self._packet_start = 0
self._packet_type: QuicPacketType | None = None
self._buffer = Buffer(max_datagram_size)
self._buffer_capacity = max_datagram_size
self._flight_capacity = max_datagram_size
@property
def packet_is_empty(self) -> bool:
"""
Returns `True` if the current packet is empty.
"""
assert self._packet is not None
packet_size = self._buffer.tell() - self._packet_start
return packet_size <= self._header_size
@property
def packet_number(self) -> int:
"""
Returns the packet number for the next packet.
"""
return self._packet_number
@property
def remaining_buffer_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._buffer_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
@property
def remaining_flight_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._flight_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
def flush(self) -> tuple[list[bytes], list[QuicSentPacket]]:
"""
Returns the assembled datagrams.
"""
if self._packet is not None:
self._end_packet()
self._flush_current_datagram()
datagrams = self._datagrams
packets = self._packets
self._datagrams = []
self._packets = []
return datagrams, packets
def start_frame(
self,
frame_type: int,
capacity: int = 1,
handler: QuicDeliveryHandler | None = None,
handler_args: Sequence[Any] = [],
) -> Buffer:
"""
Starts a new frame.
"""
if self.remaining_buffer_space < capacity or (
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
and self.remaining_flight_space < capacity
):
raise QuicPacketBuilderStop
self._buffer.push_uint_var(frame_type)
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
self._packet.is_ack_eliciting = True
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
self._packet.in_flight = True
if frame_type == QuicFrameType.CRYPTO:
self._packet.is_crypto_packet = True
if handler is not None:
self._packet.delivery_handlers.append((handler, handler_args))
return self._buffer
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
"""
Starts a new packet.
"""
assert packet_type not in {
QuicPacketType.RETRY,
QuicPacketType.VERSION_NEGOTIATION,
}, "Invalid packet type"
buf = self._buffer
# finish previous datagram
if self._packet is not None:
self._end_packet()
# if there is too little space remaining, start a new datagram
# FIXME: the limit is arbitrary!
packet_start = buf.tell()
if self._buffer_capacity - packet_start < 128:
self._flush_current_datagram()
packet_start = 0
# initialize datagram if needed
if self._datagram_init:
if self.max_total_bytes is not None:
remaining_total_bytes = self.max_total_bytes - self._total_bytes
if remaining_total_bytes < self._buffer_capacity:
self._buffer_capacity = remaining_total_bytes
self._flight_capacity = self._buffer_capacity
if self.max_flight_bytes is not None:
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
if remaining_flight_bytes < self._flight_capacity:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
self._datagram_needs_padding = False
# calculate header size
if packet_type != QuicPacketType.ONE_RTT:
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
if packet_type == QuicPacketType.INITIAL:
token_length = len(self._peer_token)
header_size += size_uint_var(token_length) + token_length
else:
header_size = 3 + len(self._peer_cid)
# check we have enough space
if packet_start + header_size >= self._buffer_capacity:
raise QuicPacketBuilderStop
# determine ack epoch
if packet_type == QuicPacketType.INITIAL:
epoch = Epoch.INITIAL
elif packet_type == QuicPacketType.HANDSHAKE:
epoch = Epoch.HANDSHAKE
else:
epoch = Epoch.ONE_RTT
self._header_size = header_size
self._packet = QuicSentPacket(
epoch=epoch,
in_flight=False,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=self._packet_number,
packet_type=packet_type,
)
self._packet_crypto = crypto
self._packet_start = packet_start
self._packet_type = packet_type
self.quic_logger_frames = self._packet.quic_logger_frames
buf.seek(self._packet_start + self._header_size)
def _end_packet(self) -> None:
"""
Ends the current packet.
"""
buf = self._buffer
packet_size = buf.tell() - self._packet_start
if packet_size > self._header_size:
# padding to ensure sufficient sample size
padding_size = (
PACKET_NUMBER_MAX_SIZE
- PACKET_NUMBER_SEND_SIZE
+ self._header_size
- packet_size
)
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if (
self._is_client or self._packet.is_ack_eliciting
) and self._packet_type == QuicPacketType.INITIAL:
self._datagram_needs_padding = True
# For datagrams containing 1-RTT data, we *must* apply the padding
# inside the packet, we cannot tack bytes onto the end of the
# datagram.
if (
self._datagram_needs_padding
and self._packet_type == QuicPacketType.ONE_RTT
):
if self.remaining_flight_space > padding_size:
padding_size = self.remaining_flight_space
self._datagram_needs_padding = False
# write padding
if padding_size > 0:
buf.push_bytes(bytes(padding_size))
packet_size += padding_size
self._packet.in_flight = True
# log frame
if self._quic_logger is not None:
self._packet.quic_logger_frames.append(
self._quic_logger.encode_padding_frame()
)
# write header
if self._packet_type != QuicPacketType.ONE_RTT:
length = (
packet_size
- self._header_size
+ PACKET_NUMBER_SEND_SIZE
+ self._packet_crypto.aead_tag_size
)
buf.seek(self._packet_start)
buf.push_uint8(
encode_long_header_first_byte(
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
)
)
buf.push_uint32(self._version)
buf.push_uint8(len(self._peer_cid))
buf.push_bytes(self._peer_cid)
buf.push_uint8(len(self._host_cid))
buf.push_bytes(self._host_cid)
if self._packet_type == QuicPacketType.INITIAL:
buf.push_uint_var(len(self._peer_token))
buf.push_bytes(self._peer_token)
buf.push_uint16(length | 0x4000)
buf.push_uint16(self._packet_number & 0xFFFF)
else:
buf.seek(self._packet_start)
buf.push_uint8(
PACKET_FIXED_BIT
| (self._spin_bit << 5)
| (self._packet_crypto.key_phase << 2)
| (PACKET_NUMBER_SEND_SIZE - 1)
)
buf.push_bytes(self._peer_cid)
buf.push_uint16(self._packet_number & 0xFFFF)
# encrypt in place
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
buf.seek(self._packet_start)
buf.push_bytes(
self._packet_crypto.encrypt_packet(
plain[0 : self._header_size],
plain[self._header_size : packet_size],
self._packet_number,
)
)
self._packet.sent_bytes = buf.tell() - self._packet_start
self._packets.append(self._packet)
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes
# Short header packets cannot be coalesced, we need a new datagram.
if self._packet_type == QuicPacketType.ONE_RTT:
self._flush_current_datagram()
self._packet_number += 1
else:
# "cancel" the packet
buf.seek(self._packet_start)
self._packet = None
self.quic_logger_frames = None
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if self._datagram_needs_padding:
extra_bytes = self._flight_capacity - self._buffer.tell()
if extra_bytes > 0:
self._buffer.push_bytes(bytes(extra_bytes))
self._datagram_flight_bytes += extra_bytes
datagram_bytes += extra_bytes
self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
self._datagram_init = True
self._buffer.seek(0)

View File

@@ -0,0 +1,503 @@
from __future__ import annotations
import logging
import math
from typing import Any, Callable, Iterable
from .._hazmat import QuicPacketPacer, QuicRttMonitor, RangeSet
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
# loss detection
K_PACKET_THRESHOLD = 3
K_GRANULARITY = 0.001 # seconds
K_TIME_THRESHOLD = 9 / 8
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
# congestion control
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
# Cubic constants (RFC 9438)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
K_CUBIC_MAX_IDLE_TIME = 2.0 # seconds
def _cubic_root(x: float) -> float:
if x < 0:
return -((-x) ** (1.0 / 3.0))
return x ** (1.0 / 3.0)
class QuicPacketSpace:
def __init__(self) -> None:
self.ack_at: float | None = None
self.ack_queue = RangeSet()
self.discarded = False
self.expected_packet_number = 0
self.largest_received_packet = -1
self.largest_received_time: float | None = None
# sent packets and loss
self.ack_eliciting_in_flight = 0
self.largest_acked_packet = 0
self.loss_time: float | None = None
self.sent_packets: dict[int, QuicSentPacket] = {}
class QuicCongestionControl:
"""
Cubic congestion control (RFC 9438).
"""
def __init__(self, max_datagram_size: int) -> None:
self._max_datagram_size = max_datagram_size
self._rtt_monitor = QuicRttMonitor()
self._congestion_recovery_start_time = 0.0
self._rtt = 0.02 # initial RTT estimate (20 ms)
self._last_ack = 0.0
self.bytes_in_flight = 0
self.congestion_window = max_datagram_size * K_INITIAL_WINDOW
self.ssthresh: int | None = None
# Cubic state
self._first_slow_start = True
self._starting_congestion_avoidance = False
self._K: float = 0.0
self._W_max: int = self.congestion_window
self._W_est: int = 0
self._cwnd_epoch: int = 0
self._t_epoch: float = 0.0
def _W_cubic(self, t: float) -> int:
W_max_segments = self._W_max / self._max_datagram_size
target_segments = K_CUBIC_C * (t - self._K) ** 3 + W_max_segments
return int(target_segments * self._max_datagram_size)
def _reset(self) -> None:
self.congestion_window = self._max_datagram_size * K_INITIAL_WINDOW
self.ssthresh = None
self._first_slow_start = True
self._starting_congestion_avoidance = False
self._K = 0.0
self._W_max = self.congestion_window
self._W_est = 0
self._cwnd_epoch = 0
self._t_epoch = 0.0
def _start_epoch(self, now: float) -> None:
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
W_max_seg = self._W_max / self._max_datagram_size
cwnd_seg = self._cwnd_epoch / self._max_datagram_size
self._K = _cubic_root((W_max_seg - cwnd_seg) / K_CUBIC_C)
def on_packet_acked(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
self._last_ack = packet.sent_time
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
if self._first_slow_start and not self._starting_congestion_avoidance:
# exiting slow start without a loss (HyStart triggered)
self._first_slow_start = False
self._W_max = self.congestion_window
self._start_epoch(packet.sent_time)
if self._starting_congestion_avoidance:
# entering congestion avoidance after a loss
self._starting_congestion_avoidance = False
self._first_slow_start = False
self._start_epoch(packet.sent_time)
# TCP-friendly estimate (Reno-like linear growth)
self._W_est = int(
self._W_est
+ self._max_datagram_size * (packet.sent_bytes / self.congestion_window)
)
t = packet.sent_time - self._t_epoch
W_cubic = self._W_cubic(t + self._rtt)
# clamp target
if W_cubic < self.congestion_window:
target = self.congestion_window
elif W_cubic > int(1.5 * self.congestion_window):
target = int(1.5 * self.congestion_window)
else:
target = W_cubic
if self._W_cubic(t) < self._W_est:
# Reno-friendly region
self.congestion_window = self._W_est
else:
# concave / convex region
self.congestion_window = int(
self.congestion_window
+ (target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
def on_packet_sent(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
# reset cwnd after prolonged idle
if self._last_ack > 0.0:
elapsed_idle = packet.sent_time - self._last_ack
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
self._reset()
def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
# fast convergence: if W_max is decreasing, reduce it further
if self.congestion_window < self._W_max:
self._W_max = int(
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
)
else:
self._W_max = self.congestion_window
self.congestion_window = max(
int(self.congestion_window * K_CUBIC_LOSS_REDUCTION_FACTOR),
self._max_datagram_size * K_MINIMUM_WINDOW,
)
self.ssthresh = self.congestion_window
self._starting_congestion_avoidance = True
def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
self._rtt = latest_rtt
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
latest_rtt, now
):
self.ssthresh = self.congestion_window
class QuicPacketRecovery:
"""
Packet loss and congestion controller.
"""
def __init__(
self,
initial_rtt: float,
peer_completed_address_validation: bool,
send_probe: Callable[[], None],
max_datagram_size: int = 1280,
logger: logging.LoggerAdapter | None = None,
quic_logger: QuicLoggerTrace | None = None,
) -> None:
self.max_ack_delay = 0.025
self.peer_completed_address_validation = peer_completed_address_validation
self.spaces: list[QuicPacketSpace] = []
# callbacks
self._logger = logger
self._quic_logger = quic_logger
self._send_probe = send_probe
# loss detection
self._pto_count = 0
self._rtt_initial = initial_rtt
self._rtt_initialized = False
self._rtt_latest = 0.0
self._rtt_min = math.inf
self._rtt_smoothed = 0.0
self._rtt_variance = 0.0
self._time_of_last_sent_ack_eliciting_packet = 0.0
# congestion control
self._cc = QuicCongestionControl(max_datagram_size)
self._pacer = QuicPacketPacer(max_datagram_size)
@property
def bytes_in_flight(self) -> int:
return self._cc.bytes_in_flight
@property
def congestion_window(self) -> int:
return self._cc.congestion_window
def discard_space(self, space: QuicPacketSpace) -> None:
assert space in self.spaces
self._cc.on_packets_expired(
filter(lambda x: x.in_flight, space.sent_packets.values())
)
space.sent_packets.clear()
space.ack_at = None
space.ack_eliciting_in_flight = 0
space.loss_time = None
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated()
def get_loss_detection_time(self) -> float:
# loss timer
loss_space = self._get_loss_space()
if loss_space is not None:
return loss_space.loss_time
# packet timer
if (
not self.peer_completed_address_validation
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
):
timeout = self.get_probe_timeout() * (2**self._pto_count)
return self._time_of_last_sent_ack_eliciting_packet + timeout
return None
def get_probe_timeout(self) -> float:
if not self._rtt_initialized:
return 2 * self._rtt_initial
return (
self._rtt_smoothed
+ max(4 * self._rtt_variance, K_GRANULARITY)
+ self.max_ack_delay
)
def on_ack_received(
self,
space: QuicPacketSpace,
ack_rangeset: RangeSet,
ack_delay: float,
now: float,
) -> None:
"""
Update metrics as the result of an ACK being received.
"""
is_ack_eliciting = False
largest_acked = ack_rangeset.bounds()[1] - 1
largest_newly_acked = None
largest_sent_time = None
if largest_acked > space.largest_acked_packet:
space.largest_acked_packet = largest_acked
for packet_number in sorted(space.sent_packets.keys()):
if packet_number > largest_acked:
break
if packet_number in ack_rangeset:
# remove packet and update counters
packet = space.sent_packets.pop(packet_number)
if packet.is_ack_eliciting:
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.ACKED, *args)
# nothing to do if there are no newly acked packets
if largest_newly_acked is None:
return
if largest_acked == largest_newly_acked and is_ack_eliciting:
latest_rtt = now - largest_sent_time
log_rtt = True
# limit ACK delay to max_ack_delay
ack_delay = min(ack_delay, self.max_ack_delay)
# update RTT estimate, which cannot be < 1 ms
self._rtt_latest = max(latest_rtt, 0.001)
if self._rtt_latest < self._rtt_min:
self._rtt_min = self._rtt_latest
if self._rtt_latest > self._rtt_min + ack_delay:
self._rtt_latest -= ack_delay
if not self._rtt_initialized:
self._rtt_initialized = True
self._rtt_variance = latest_rtt / 2
self._rtt_smoothed = latest_rtt
else:
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
self._rtt_min - self._rtt_latest
)
self._rtt_smoothed = (
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
)
# inform congestion controller
self._cc.on_rtt_measurement(latest_rtt, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
else:
log_rtt = False
self._detect_loss(space, now=now)
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated(log_rtt=log_rtt)
def on_loss_detection_timeout(self, now: float) -> None:
loss_space = self._get_loss_space()
if loss_space is not None:
self._detect_loss(loss_space, now=now)
else:
self._pto_count += 1
self.reschedule_data(now=now)
def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
space.sent_packets[packet.packet_number] = packet
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight += 1
if packet.in_flight:
if packet.is_ack_eliciting:
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
# add packet to bytes in flight
self._cc.on_packet_sent(packet)
if self._quic_logger is not None:
self._log_metrics_updated()
def reschedule_data(self, now: float) -> None:
"""
Schedule some data for retransmission.
"""
# if there is any outstanding CRYPTO, retransmit it
crypto_scheduled = False
for space in self.spaces:
packets = tuple(
filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
)
if packets:
self._on_packets_lost(packets, space=space, now=now)
crypto_scheduled = True
if crypto_scheduled and self._logger is not None:
self._logger.debug("Scheduled CRYPTO data for retransmission")
# ensure an ACK-elliciting packet is sent
self._send_probe()
def _detect_loss(self, space: QuicPacketSpace, now: float) -> None:
"""
Check whether any packets should be declared lost.
"""
loss_delay = K_TIME_THRESHOLD * (
max(self._rtt_latest, self._rtt_smoothed)
if self._rtt_initialized
else self._rtt_initial
)
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
time_threshold = now - loss_delay
lost_packets = []
space.loss_time = None
for packet_number, packet in space.sent_packets.items():
if packet_number > space.largest_acked_packet:
break
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
lost_packets.append(packet)
else:
packet_loss_time = packet.sent_time + loss_delay
if space.loss_time is None or space.loss_time > packet_loss_time:
space.loss_time = packet_loss_time
self._on_packets_lost(lost_packets, space=space, now=now)
def _get_loss_space(self) -> QuicPacketSpace | None:
loss_space = None
for space in self.spaces:
if space.loss_time is not None and (
loss_space is None or space.loss_time < loss_space.loss_time
):
loss_space = space
return loss_space
def _log_metrics_updated(self, log_rtt=False) -> None:
data: dict[str, Any] = {
"bytes_in_flight": self._cc.bytes_in_flight,
"cwnd": self._cc.congestion_window,
}
if self._cc.ssthresh is not None:
data["ssthresh"] = self._cc.ssthresh
if log_rtt:
data.update(
{
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
}
)
self._quic_logger.log_event(
category="recovery", event="metrics_updated", data=data
)
def _on_packets_lost(
self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float
) -> None:
lost_packets_cc = []
for packet in packets:
del space.sent_packets[packet.packet_number]
if packet.in_flight:
lost_packets_cc.append(packet)
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight -= 1
if self._quic_logger is not None:
self._quic_logger.log_event(
category="recovery",
event="packet_lost",
data={
"type": self._quic_logger.packet_type(packet.packet_type),
"packet_number": packet.packet_number,
},
)
self._log_metrics_updated()
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.LOST, *args)
# inform congestion controller
if lost_packets_cc:
self._cc.on_packets_lost(lost_packets_cc, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
if self._quic_logger is not None:
self._log_metrics_updated()

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import ipaddress
from .._hazmat import Buffer, Rsa
from ..tls import pull_opaque, push_opaque
from .connection import NetworkAddress
def encode_address(addr: NetworkAddress) -> bytes:
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
class QuicRetryTokenHandler:
def __init__(self) -> None:
self._key = Rsa(key_size=2048)
def create_token(
self,
addr: NetworkAddress,
original_destination_connection_id: bytes,
retry_source_connection_id: bytes,
) -> bytes:
buf = Buffer(capacity=512)
push_opaque(buf, 1, encode_address(addr))
push_opaque(buf, 1, original_destination_connection_id)
push_opaque(buf, 1, retry_source_connection_id)
return self._key.encrypt(buf.data)
def validate_token(self, addr: NetworkAddress, token: bytes) -> tuple[bytes, bytes]:
if not token or len(token) != 256:
raise ValueError("Ciphertext length must be equal to key size.")
buf = Buffer(data=self._key.decrypt(token))
encoded_addr = pull_opaque(buf, 1)
original_destination_connection_id = pull_opaque(buf, 1)
retry_source_connection_id = pull_opaque(buf, 1)
if encoded_addr != encode_address(addr):
raise ValueError("Remote address does not match.")
return original_destination_connection_id, retry_source_connection_id

View File

@@ -0,0 +1,380 @@
from __future__ import annotations
from .._hazmat import RangeSet
from . import events
from .packet import (
QuicErrorCode,
QuicResetStreamFrame,
QuicStopSendingFrame,
QuicStreamFrame,
)
from .packet_builder import QuicDeliveryState
class FinalSizeError(Exception):
pass
class StreamFinishedError(Exception):
pass
class QuicStreamReceiver:
"""
The receive part of a QUIC stream.
It finishes:
- immediately for a send-only stream
- upon reception of a STREAM_RESET frame
- upon reception of a data frame with the FIN bit set
"""
__slots__ = (
"highest_offset",
"is_finished",
"stop_pending",
"_buffer",
"_buffer_start",
"_final_size",
"_ranges",
"_stream_id",
"_stop_error_code",
)
def __init__(self, stream_id: int | None, readable: bool) -> None:
self.highest_offset = 0 # the highest offset ever seen
self.is_finished = False
self.stop_pending = False
self._buffer = bytearray()
self._buffer_start = 0 # the offset for the start of the buffer
self._final_size: int | None = None
self._ranges = RangeSet()
self._stream_id = stream_id
self._stop_error_code: int | None = None
def get_stop_frame(self) -> QuicStopSendingFrame:
self.stop_pending = False
return QuicStopSendingFrame(
error_code=self._stop_error_code,
stream_id=self._stream_id,
)
def starting_offset(self) -> int:
return self._buffer_start
def handle_frame(self, frame: QuicStreamFrame) -> events.StreamDataReceived | None:
"""
Handle a frame of received data.
"""
pos = frame.offset - self._buffer_start
count = len(frame.data)
frame_end = frame.offset + count
# we should receive no more data beyond FIN!
if self._final_size is not None:
if frame_end > self._final_size:
raise FinalSizeError("Data received beyond final size")
elif frame.fin and frame_end != self._final_size:
raise FinalSizeError("Cannot change final size")
if frame.fin:
self._final_size = frame_end
if frame_end > self.highest_offset:
self.highest_offset = frame_end
# fast path: new in-order chunk
if pos == 0 and count and not self._buffer:
self._buffer_start += count
if frame.fin:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
return events.StreamDataReceived(
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
)
# discard duplicate data
if pos < 0:
frame.data = frame.data[-pos:]
frame.offset -= pos
pos = 0
count = len(frame.data)
# marked received range
if frame_end > frame.offset:
self._ranges.add(frame.offset, frame_end)
# add new data
gap = pos - len(self._buffer)
if gap > 0:
self._buffer += bytearray(gap)
self._buffer[pos : pos + count] = frame.data
# return data from the front of the buffer
data = self._pull_data()
end_stream = self._buffer_start == self._final_size
if end_stream:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
if data or end_stream:
return events.StreamDataReceived(
data=data, end_stream=end_stream, stream_id=self._stream_id
)
else:
return None
def handle_reset(
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
) -> events.StreamReset | None:
"""
Handle an abrupt termination of the receiving part of the QUIC stream.
"""
if self._final_size is not None and final_size != self._final_size:
raise FinalSizeError("Cannot change final size")
# we are done receiving
self._final_size = final_size
self.is_finished = True
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a STOP_SENDING is ACK'd.
"""
if delivery != QuicDeliveryState.ACKED:
self.stop_pending = True
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
"""
Request the peer stop sending data on the QUIC stream.
"""
self._stop_error_code = error_code
self.stop_pending = True
def _pull_data(self) -> bytes:
"""
Remove data from the front of the buffer.
"""
try:
has_data_to_read = self._ranges[0][0] == self._buffer_start
except IndexError:
has_data_to_read = False
if not has_data_to_read:
return b""
r = self._ranges.shift()
pos = r[1] - r[0]
data = bytes(self._buffer[:pos])
del self._buffer[:pos]
self._buffer_start = r[1]
return data
class QuicStreamSender:
"""
The send part of a QUIC stream.
It finishes:
- immediately for a receive-only stream
- upon acknowledgement of a STREAM_RESET frame
- upon acknowledgement of a data frame with the FIN bit set
"""
__slots__ = (
"buffer_is_empty",
"highest_offset",
"is_finished",
"reset_pending",
"_acked",
"_buffer",
"_buffer_fin",
"_buffer_start",
"_buffer_stop",
"_pending",
"_pending_eof",
"_reset_error_code",
"_stream_id",
"send_buffer_empty",
)
def __init__(self, stream_id: int | None, writable: bool) -> None:
self.buffer_is_empty = True
self.highest_offset = 0
self.is_finished = not writable
self.reset_pending = False
self._acked = RangeSet()
self._buffer = bytearray()
self._buffer_fin: int | None = None
self._buffer_start = 0 # the offset for the start of the buffer
self._buffer_stop = 0 # the offset for the stop of the buffer
self._pending = RangeSet()
self._pending_eof = False
self._reset_error_code: int | None = None
self._stream_id = stream_id
@property
def next_offset(self) -> int:
"""
The offset for the next frame to send.
This is used to determine the space needed for the frame's `offset` field.
"""
try:
return self._pending[0][0]
except IndexError:
return self._buffer_stop
def get_frame(
self, max_size: int, max_offset: int | None = None
) -> QuicStreamFrame | None:
"""
Get a frame of data to send.
"""
# get the first pending data range
try:
r = self._pending[0]
except IndexError:
if self._pending_eof:
# FIN only
self._pending_eof = False
return QuicStreamFrame(fin=True, offset=self._buffer_fin)
self.buffer_is_empty = True
return None
# apply flow control
start = r[0]
stop = min(r[1], start + max_size)
if max_offset is not None and stop > max_offset:
stop = max_offset
if stop <= start:
return None
# create frame
frame = QuicStreamFrame(
data=bytes(
self._buffer[start - self._buffer_start : stop - self._buffer_start]
),
offset=start,
)
self._pending.subtract(start, stop)
# track the highest offset ever sent
if stop > self.highest_offset:
self.highest_offset = stop
# if the buffer is empty and EOF was written, set the FIN bit
if self._buffer_fin == stop:
frame.fin = True
self._pending_eof = False
return frame
def get_reset_frame(self) -> QuicResetStreamFrame:
self.reset_pending = False
return QuicResetStreamFrame(
error_code=self._reset_error_code,
final_size=self.highest_offset,
stream_id=self._stream_id,
)
def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
) -> None:
"""
Callback when sent data is ACK'd.
"""
self.buffer_is_empty = False
if delivery == QuicDeliveryState.ACKED:
if stop > start:
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range[0] == self._buffer_start:
size = first_range[1] - first_range[0]
self._acked.shift()
self._buffer_start += size
del self._buffer[:size]
if self._buffer_start == self._buffer_fin:
# all date up to the FIN has been ACK'd, we're done sending
self.is_finished = True
else:
if stop > start:
self._pending.add(start, stop)
if stop == self._buffer_fin:
self.send_buffer_empty = False
self._pending_eof = True
def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a reset is ACK'd.
"""
if delivery == QuicDeliveryState.ACKED:
# the reset has been ACK'd, we're done sending
self.is_finished = True
else:
self.reset_pending = True
def reset(self, error_code: int) -> None:
"""
Abruptly terminate the sending part of the QUIC stream.
Once this method has been called, any further calls to it
will have no effect.
"""
if self._reset_error_code is None:
self._reset_error_code = error_code
self.reset_pending = True
# Prevent any more data from being sent or re-sent.
self.buffer_is_empty = True
def write(self, data: bytes, end_stream: bool = False) -> None:
"""
Write some data bytes to the QUIC stream.
"""
assert self._buffer_fin is None, "cannot call write() after FIN"
assert self._reset_error_code is None, "cannot call write() after reset()"
size = len(data)
if size:
self.buffer_is_empty = False
self._pending.add(self._buffer_stop, self._buffer_stop + size)
self._buffer += data
self._buffer_stop += size
if end_stream:
self.buffer_is_empty = False
self._buffer_fin = self._buffer_stop
self._pending_eof = True
class QuicStream:
__slots__ = (
"is_blocked",
"max_stream_data_local",
"max_stream_data_local_sent",
"max_stream_data_remote",
"receiver",
"sender",
"stream_id",
)
def __init__(
self,
stream_id: int | None = None,
max_stream_data_local: int = 0,
max_stream_data_remote: int = 0,
readable: bool = True,
writable: bool = True,
) -> None:
self.is_blocked = False
self.max_stream_data_local = max_stream_data_local
self.max_stream_data_local_sent = max_stream_data_local
self.max_stream_data_remote = max_stream_data_remote
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
self.stream_id = stream_id
@property
def is_finished(self) -> bool:
return self.receiver.is_finished and self.sender.is_finished