Jellyfin(8096), OrbStack(8097) 포트 충돌로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2279 lines
77 KiB
Python
2279 lines
77 KiB
Python
from __future__ import annotations
|
|
|
|
import datetime
|
|
import glob
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
import re
|
|
import ssl
|
|
import struct
|
|
from binascii import unhexlify
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from enum import IntEnum
|
|
from functools import lru_cache, partial
|
|
from hmac import HMAC
|
|
from typing import Any, Callable, Generator, Optional, Sequence, Tuple, TypeVar
|
|
|
|
from ._hazmat import (
|
|
Buffer,
|
|
BufferReadError,
|
|
CryptoError,
|
|
DsaPrivateKey,
|
|
ECDHP256KeyExchange,
|
|
ECDHP384KeyExchange,
|
|
ECDHP521KeyExchange,
|
|
EcPrivateKey,
|
|
Ed25519PrivateKey,
|
|
ExpiredCertificateError,
|
|
InvalidNameCertificateError,
|
|
KeyType,
|
|
PrivateKeyInfo,
|
|
RsaPrivateKey,
|
|
SelfSignedCertificateError,
|
|
ServerVerifier,
|
|
SignatureError,
|
|
TlsCertUsage,
|
|
UnacceptableCertificateError,
|
|
X25519KeyExchange,
|
|
X25519ML768KeyExchange,
|
|
idna_encode,
|
|
rebuild_chain,
|
|
verify_with_public_key,
|
|
)
|
|
from ._hazmat import (
|
|
Certificate as X509Certificate,
|
|
)
|
|
|
|
_HASHED_CERT_FILENAME_RE = re.compile(r"^[0-9a-fA-F]{8}\.[0-9]$")
|
|
|
|
TLS_VERSION_GREASE = 0x0A0A
|
|
TLS_VERSION_1_2 = 0x0303
|
|
TLS_VERSION_1_3 = 0x0304
|
|
|
|
T = TypeVar("T")
|
|
|
|
# Maps the length of a hex-encoded digest to a hashlib constructor able to
# produce a digest of that length. getattr() falls back to None because some
# algorithms may be unavailable in the running hashlib build.
HASHFUNC_MAP = dict(
    (hex_length, getattr(hashlib, name, None))
    for hex_length, name in ((32, "md5"), (40, "sha1"), (64, "sha256"))
)
|
|
|
|
|
|
# facilitate mocking for the test suite
def utcnow(remove_tz: bool = True) -> datetime.datetime:
    """
    Return the current time in UTC.

    :param remove_tz: when true (the default) the tzinfo is stripped and a
        naive datetime is returned; otherwise the aware UTC value is returned.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    if not remove_tz:
        return now
    return now.replace(tzinfo=None)
|
|
|
|
|
|
class AlertDescription(IntEnum):
    """TLS alert description codes (one-byte values)."""

    # closure / generic errors
    close_notify = 0
    unexpected_message = 10
    bad_record_mac = 20
    record_overflow = 22
    handshake_failure = 40

    # certificate related
    bad_certificate = 42
    unsupported_certificate = 43
    certificate_revoked = 44
    certificate_expired = 45
    certificate_unknown = 46

    # parameter / decoding / policy errors
    illegal_parameter = 47
    unknown_ca = 48
    access_denied = 49
    decode_error = 50
    decrypt_error = 51
    protocol_version = 70
    insufficient_security = 71
    internal_error = 80
    inappropriate_fallback = 86
    user_canceled = 90

    # extension related
    missing_extension = 109
    unsupported_extension = 110
    unrecognized_name = 112
    bad_certificate_status_response = 113
    unknown_psk_identity = 115
    certificate_required = 116
    no_application_protocol = 120
|
|
|
|
|
|
class Alert(Exception):
    """
    Base exception for TLS errors that are reported to the peer as an alert.

    Each subclass pins ``description`` to the alert code it stands for.
    """

    description: AlertDescription


class AlertBadCertificate(Alert):
    description = AlertDescription.bad_certificate


class AlertCertificateExpired(Alert):
    description = AlertDescription.certificate_expired


class AlertDecryptError(Alert):
    description = AlertDescription.decrypt_error


class AlertHandshakeFailure(Alert):
    description = AlertDescription.handshake_failure


class AlertIllegalParameter(Alert):
    description = AlertDescription.illegal_parameter


class AlertInternalError(Alert):
    description = AlertDescription.internal_error


class AlertProtocolVersion(Alert):
    description = AlertDescription.protocol_version


class AlertUnexpectedMessage(Alert):
    description = AlertDescription.unexpected_message
|
|
|
|
|
|
class Direction(IntEnum):
    """Whether a crypto context decrypts (0) or encrypts (1)."""

    DECRYPT = 0
    ENCRYPT = 1


class Epoch(IntEnum):
    """Encryption epochs, in the order they are used during a connection."""

    INITIAL = 0
    ZERO_RTT = 1
    HANDSHAKE = 2
    ONE_RTT = 3


class State(IntEnum):
    """Handshake state machine states for the client (0-7) and server (8-10)."""

    # client side
    CLIENT_HANDSHAKE_START = 0
    CLIENT_EXPECT_SERVER_HELLO = 1
    CLIENT_EXPECT_ENCRYPTED_EXTENSIONS = 2
    CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE = 3
    CLIENT_EXPECT_CERTIFICATE_CERTIFICATE = 4
    CLIENT_EXPECT_CERTIFICATE_VERIFY = 5
    CLIENT_EXPECT_FINISHED = 6
    CLIENT_POST_HANDSHAKE = 7

    # server side
    SERVER_EXPECT_CLIENT_HELLO = 8
    SERVER_EXPECT_FINISHED = 9
    SERVER_POST_HANDSHAKE = 10
|
|
|
|
|
|
class HKDFExpand:
    """
    Single-use HKDF-Expand (RFC 5869) built on HMAC-SHA2.

    :param algorithm: hash size in bits; the digest is ``f"sha{algorithm}"``.
    :param length: number of output bytes, at most 255 * digest size.
    :param info: optional context bytes; ``None`` is treated as ``b""``.
    :raises ValueError: when ``length`` exceeds the maximum output size.
    """

    def __init__(
        self,
        algorithm: int,
        length: int,
        info: bytes | None,
    ):
        self._algorithm = algorithm
        self._digest_size = int(algorithm / 8)

        limit = 255 * self._digest_size
        if length > limit:
            raise ValueError(f"Cannot derive keys larger than {limit} octets.")

        self._length = length
        self._info = b"" if info is None else info
        self._used = False

    def _expand(self, key_material: bytes) -> bytes:
        """Chain T(1) | T(2) | ... and truncate to the requested length."""
        digestmod = f"sha{self._algorithm}"
        blocks: list[bytes] = []
        previous = b""

        while self._digest_size * len(blocks) < self._length:
            mac = HMAC(key_material, digestmod=digestmod)
            mac.update(previous)
            mac.update(self._info)
            # one-byte block counter, starting at 1
            mac.update(bytes([len(blocks) + 1]))
            previous = mac.digest()
            blocks.append(previous)

        return b"".join(blocks)[: self._length]

    def derive(self, key_material: bytes) -> bytes:
        """Expand ``key_material``; a second call raises ``CryptoError``."""
        if self._used:
            raise CryptoError

        self._used = True
        return self._expand(key_material)
|
|
|
|
|
|
def hkdf_label(label: bytes, hash_value: bytes, length: int) -> bytes:
    """
    Encode the TLS 1.3 HkdfLabel structure.

    Layout: uint16 ``length``, a one-byte-length-prefixed ``"tls13 " + label``,
    then a one-byte-length-prefixed ``hash_value`` (the context).
    """
    prefixed_label = b"tls13 " + label
    parts = (
        struct.pack("!HB", length, len(prefixed_label)),
        prefixed_label,
        struct.pack("!B", len(hash_value)),
        hash_value,
    )
    return b"".join(parts)
|
|
|
|
|
|
def hkdf_expand_label(
    algorithm: int,
    secret: bytes,
    label: bytes,
    hash_value: bytes,
    length: int,
) -> bytes:
    """
    TLS 1.3 HKDF-Expand-Label.

    Expands ``secret`` to ``length`` bytes, using the encoded HkdfLabel
    (built from ``label``, ``hash_value`` and ``length``) as the info input.
    """
    info = hkdf_label(label, hash_value, length)
    expander = HKDFExpand(algorithm=algorithm, length=length, info=info)
    return expander.derive(secret)
|
|
|
|
|
|
def hkdf_extract(algorithm: int, salt: bytes, key_material: bytes) -> bytes:
    """HKDF-Extract: ``HMAC-SHA{algorithm}(salt, key_material)``."""
    return HMAC(salt, key_material, digestmod=f"sha{algorithm}").digest()
|
|
|
|
|
|
def load_pem_private_key(
    data: bytes, password: bytes | None = None
) -> EcPrivateKey | DsaPrivateKey | RsaPrivateKey | Ed25519PrivateKey:
    """
    Load a PEM-encoded private key.

    The key type reported by ``PrivateKeyInfo`` selects the wrapper class.
    ECDSA keys also receive their curve size (256, 384 or 521) and a flag
    that is False when the PEM text contains ``BEGIN EC PRIVATE KEY`` and
    True otherwise.

    :raises ssl.SSLError: for any other key type.
    """
    pkey_info = PrivateKeyInfo(data, password)
    key_type = pkey_info.get_type()

    for ec_type, curve_bits in (
        (KeyType.ECDSA_P256, 256),
        (KeyType.ECDSA_P384, 384),
        (KeyType.ECDSA_P521, 521),
    ):
        if key_type == ec_type:
            return EcPrivateKey(
                pkey_info.public_bytes(),
                curve_bits,
                b"BEGIN EC PRIVATE KEY" not in data,
            )

    for other_type, key_class in (
        (KeyType.DSA, DsaPrivateKey),
        (KeyType.RSA, RsaPrivateKey),
        (KeyType.ED25519, Ed25519PrivateKey),
    ):
        if key_type == other_type:
            return key_class(pkey_info.public_bytes())

    raise ssl.SSLError("Unsupported private key format")
|
|
|
|
|
|
def load_pem_x509_certificates(data: bytes) -> list[X509Certificate]:
    """
    Load a chain of PEM-encoded X509 certificates.

    Every ``-----BEGIN CERTIFICATE-----`` ... ``-----END CERTIFICATE-----``
    section found in ``data`` is decoded, in order of appearance. LF, CRLF
    and mixed line endings are accepted. The last certificate does not need
    a trailing newline. Text outside the markers (comments, keys) is ignored.

    :raises ValueError: when a section is not valid ASCII/base64 (raised by
        ``bytes.decode`` / ``ssl.PEM_cert_to_DER_cert``).
    """
    certificates = []

    # Non-greedy match so that each BEGIN pairs with its nearest END.
    # Matching the markers themselves (rather than splitting on
    # "END CERTIFICATE" + newline) avoids decoding the closing marker as
    # base64 when the input lacks a trailing newline.
    for match in re.finditer(
        rb"-----BEGIN CERTIFICATE-----(.*?)-----END CERTIFICATE-----",
        data,
        re.DOTALL,
    ):
        # Rebuild a canonical PEM block accepted by ssl.PEM_cert_to_DER_cert.
        pem_reconstructed = (
            b"-----BEGIN CERTIFICATE-----"
            + match.group(1)
            + b"-----END CERTIFICATE-----\n"
        ).decode("ascii")

        certificates.append(
            X509Certificate(ssl.PEM_cert_to_DER_cert(pem_reconstructed))
        )

    return certificates
|
|
|
|
|
|
def _capath_contains_certs(capath: str) -> bool:
    """
    Check whether capath exists and contains certs in the expected format.

    Returns False when ``capath`` is not a directory. Otherwise returns True
    as soon as one entry matches ``_HASHED_CERT_FILENAME_RE``.
    """
    return os.path.isdir(capath) and any(
        _HASHED_CERT_FILENAME_RE.match(entry) for entry in os.listdir(capath)
    )
|
|
|
|
|
|
@lru_cache(maxsize=64)
def load_store_and_sort(
    cadata: bytes | None = None,
    cafile: str | None = None,
    capath: str | None = None,
) -> tuple[list[X509Certificate], list[X509Certificate], list[X509Certificate]]:
    """
    Given cadata, cafile and capath load X509 certificates and sort
    them into three distinct lists:

    - Trust anchors (self-signed CAs whose usage is ``TlsCertUsage.Other``)
    - Intermediates (certificates that are not self-signed)
    - Others (self-signed certificates with any other usage; not suitable
      for our purposes)

    When cadata, cafile and capath are all None, the CA certificates of a
    default ``ssl`` context are loaded instead.

    ``capath`` is only read when it holds at least one OpenSSL hashed file
    name. In that case every regular, non-hidden file is loaded in sorted
    order. A file reachable under several names (e.g. hash symlinks next to
    the original PEM) is loaded once. Sub-directories and dangling links are
    skipped.

    This function consumes a lot of CPU time, so it is cached. The returned
    lists are shared between calls with the same arguments, so callers must
    not mutate them.
    """
    trust_anchors: list[X509Certificate] = []
    intermediaries: list[X509Certificate] = []
    others: list[X509Certificate] = []

    def _sort_cert_in_appropriate_list(c: X509Certificate) -> None:
        # Route one certificate into the matching output list.
        if not c.self_signed:
            intermediaries.append(c)
        elif c.usage != TlsCertUsage.Other:
            # root CA must be tagged OTHER
            # EKU must not have client auth
            # or server auth. it's a red flag! period.
            others.append(c)
        else:
            trust_anchors.append(c)

    def _load_pem_file(path: str) -> None:
        # Read one PEM bundle from disk and sort every certificate it holds.
        with open(path, "rb") as fp:
            for cert in load_pem_x509_certificates(fp.read()):
                _sort_cert_in_appropriate_list(cert)

    if cadata is not None:
        for cert in load_pem_x509_certificates(cadata):
            _sort_cert_in_appropriate_list(cert)

    if cafile:
        _load_pem_file(cafile)

    if capath and _capath_contains_certs(capath):
        seen_real_paths: set[str] = set()

        # os.listdir instead of glob(f"{capath}/*"): capath may contain glob
        # metacharacters. Hidden names are skipped, as "*" never matched them.
        for name in sorted(os.listdir(capath)):
            if name.startswith("."):
                continue

            path = os.path.join(capath, name)

            if not os.path.isfile(path):
                continue  # sub-directory, dangling link, special file

            real_path = os.path.realpath(path)
            if real_path in seen_real_paths:
                continue  # e.g. "<hash>.0" symlink to an already loaded PEM
            seen_real_paths.add(real_path)

            _load_pem_file(path)

    if cadata is None and cafile is None and capath is None:
        default_ctx = ssl.create_default_context()
        default_ctx.load_default_certs()

        for ca in default_ctx.get_ca_certs(binary_form=True):
            _sort_cert_in_appropriate_list(X509Certificate(ca))

    return trust_anchors, intermediaries, others
|
|
|
|
|
|
def _server_name_from_alt_names(alt_names) -> str | None:
    """
    Pick a server name from subjectAltName entries rendered as bytes.

    Entries look like ``b"DNSName(example.com)"`` or
    ``b"IPAddress(08:08:08:08)"``. Only the text between the parentheses is
    used, and the first usable entry wins:

    * IP addresses are colon-separated hexadecimal octets, for example
      ``IPAddress(20:01:48:60:48:60:00:00:00:00:00:00:00:00:00:64)`` for
      IPv6. Only IPv4 (four octets, 11 characters) is converted to
      dotted-decimal; other addresses are skipped.
    * DNS names are returned with ``"*."`` replaced by ``"unverified."``, so a
      wildcard entry yields a concrete host name that it covers.

    Returns None when no entry is usable.
    """
    for alt_name in alt_names:
        text = alt_name.decode()
        candidate = text[text.find("(") + 1 : text.find(")")]

        if ":" not in candidate:
            return candidate.replace("*.", "unverified.")

        if len(candidate) == 11:
            # e.g. "C0:A8:01:01" -> "192.168.1.1" (octets are hexadecimal)
            return ".".join(str(int(octet, 16)) for octet in candidate.split(":"))

    return None


def verify_certificate(
    certificate: X509Certificate,
    chain: list[X509Certificate] | None = None,
    cadata: bytes | None = None,
    cafile: str | None = None,
    capath: str | None = None,
    server_name: str | None = None,
    assert_server_name: bool = True,
    ocsp_response: bytes | None = None,
) -> None:
    """
    Verify ``certificate`` against the trust store built from
    ``cadata`` / ``cafile`` / ``capath`` (see ``load_store_and_sort``).

    :param chain: intermediates sent by the peer. When empty, the chain is
        rebuilt from the locally configured intermediates, if any. A
        caller-supplied list is never modified.
    :param server_name: expected host name. When it is None, or when
        ``assert_server_name`` is False, a name is taken from the
        certificate's subjectAltName entries when one is usable.
    :param assert_server_name: when False, a name mismatch reported by the
        verifier is ignored.
    :param ocsp_response: optional stapled OCSP response passed to the
        verifier (``b""`` when absent).
    :raises AlertBadCertificate: when no server name can be determined, the
        CA store is empty or unusable, or verification fails.
    """
    trust_anchors, intermediaries, _ = load_store_and_sort(
        cadata=cadata,
        cafile=cafile,
        capath=capath,
    )

    if server_name is None or assert_server_name is False:
        inferred_name = _server_name_from_alt_names(
            certificate.get_subject_alt_names()
        )
        if inferred_name is not None:
            server_name = inferred_name

    if server_name is None:
        raise AlertBadCertificate("unable to determine server name target")

    if not trust_anchors:
        raise AlertBadCertificate(
            "unable to get local issuer certificate (empty CA store)"
        )

    # Rebuild the intermediate chain locally in case the server did not pass
    # them along and the configuration does hold a list of intermediates.
    # A fresh list is built so the caller's (possibly empty) list is untouched.
    if not chain:
        chain = []
        if intermediaries:
            raw_chain = rebuild_chain(
                certificate.public_bytes(),
                [c.public_bytes() for c in intermediaries],
            )
            # raw_chain[0] is presumably the leaf itself; keep only issuers.
            chain = [X509Certificate(der) for der in raw_chain[1:]]

    # load CAs
    try:
        store = ServerVerifier([c.public_bytes() for c in trust_anchors])
    except CryptoError as e:
        raise AlertBadCertificate("unable to create the verifier x509 store") from e

    try:
        store.verify(
            certificate.public_bytes(),
            [c.public_bytes() for c in chain],
            server_name,
            ocsp_response or b"",
        )
    except (
        SelfSignedCertificateError,
        InvalidNameCertificateError,
        ExpiredCertificateError,
        UnacceptableCertificateError,
    ) as exc:
        if isinstance(exc, InvalidNameCertificateError) and assert_server_name is False:
            return
        raise AlertBadCertificate(exc.args[0] if exc.args else str(exc)) from exc
|
|
|
|
|
|
class CipherSuite(IntEnum):
    """Cipher suite code points (TLS 1.3 suites plus signalling values)."""

    AES_128_GCM_SHA256 = 0x1301
    AES_256_GCM_SHA384 = 0x1302
    CHACHA20_POLY1305_SHA256 = 0x1303
    EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF
    GREASE = 0xDADA


class CompressionMethod(IntEnum):
    """Legacy compression methods; only NULL is defined."""

    NULL = 0


class ExtensionType(IntEnum):
    """Extension type code points understood by this module."""

    SERVER_NAME = 0
    STATUS_REQUEST = 5
    SUPPORTED_GROUPS = 10
    SIGNATURE_ALGORITHMS = 13
    ALPN = 16
    COMPRESS_CERTIFICATE = 27
    PRE_SHARED_KEY = 41
    EARLY_DATA = 42
    SUPPORTED_VERSIONS = 43
    COOKIE = 44
    PSK_KEY_EXCHANGE_MODES = 45
    KEY_SHARE = 51
    QUIC_TRANSPORT_PARAMETERS = 0x0039
    ENCRYPTED_SERVER_NAME = 65486
    GREASE = 0x0A0A


class Group(IntEnum):
    """Named groups for key exchange."""

    SECP256R1 = 0x0017
    SECP384R1 = 0x0018
    SECP521R1 = 0x0019
    X25519KYBER768DRAFT00 = 0x6399
    X25519ML768 = 0x11EC
    X25519 = 0x001D
    X448 = 0x001E
    GREASE = 0xAAAA


class HandshakeType(IntEnum):
    """Handshake message type codes."""

    CLIENT_HELLO = 1
    SERVER_HELLO = 2
    NEW_SESSION_TICKET = 4
    END_OF_EARLY_DATA = 5
    ENCRYPTED_EXTENSIONS = 8
    CERTIFICATE = 11
    CERTIFICATE_REQUEST = 13
    CERTIFICATE_VERIFY = 15
    FINISHED = 20
    KEY_UPDATE = 24
    COMPRESSED_CERTIFICATE = 25
    MESSAGE_HASH = 254


class PskKeyExchangeMode(IntEnum):
    """PSK key exchange modes (PSK only, or PSK with (EC)DHE)."""

    PSK_KE = 0
    PSK_DHE_KE = 1


class SignatureAlgorithm(IntEnum):
    """Signature scheme code points."""

    # ECDSA / EdDSA
    ECDSA_SECP256R1_SHA256 = 0x0403
    ECDSA_SECP384R1_SHA384 = 0x0503
    ECDSA_SECP521R1_SHA512 = 0x0603
    ED25519 = 0x0807
    ED448 = 0x0808  # unsupported

    # RSA PKCS#1 v1.5
    RSA_PKCS1_SHA256 = 0x0401
    RSA_PKCS1_SHA384 = 0x0501
    RSA_PKCS1_SHA512 = 0x0601

    # RSA-PSS
    RSA_PSS_PSS_SHA256 = 0x0809
    RSA_PSS_PSS_SHA384 = 0x080A
    RSA_PSS_PSS_SHA512 = 0x080B
    RSA_PSS_RSAE_SHA256 = 0x0804
    RSA_PSS_RSAE_SHA384 = 0x0805
    RSA_PSS_RSAE_SHA512 = 0x0806

    # unsafe, and unsupported (by us)!
    RSA_PKCS1_SHA1 = 0x0201
    SHA1_DSA = 0x0202
    ECDSA_SHA1 = 0x0203
|
|
|
|
|
|
# BLOCKS
|
|
|
|
|
|
@contextmanager
def pull_block(buf: Buffer, capacity: int) -> Generator:
    """
    Context manager to pull a variable-length block whose length is stored
    big-endian in `capacity` bytes.

    Yields the block length. On normal exit it asserts that the body consumed
    exactly that many bytes.
    """
    size = int.from_bytes(buf.pull_bytes(capacity), "big")
    expected_end = buf.tell() + size
    yield size
    assert buf.tell() == expected_end


@contextmanager
def push_block(buf: Buffer, capacity: int) -> Generator:
    """
    Context manager to push a variable-length block, with `capacity` bytes
    to write the length.

    Space for the length is skipped first. Once the body has been written,
    the real length is back-filled and the cursor returns to the block end.
    """
    header_pos = buf.tell()
    body_start = header_pos + capacity
    buf.seek(body_start)
    yield
    body_end = buf.tell()
    buf.seek(header_pos)
    buf.push_bytes((body_end - body_start).to_bytes(capacity, byteorder="big"))
    buf.seek(body_end)
|
|
|
|
|
|
# LISTS
|
|
|
|
|
|
def pull_list(buf: Buffer, capacity: int, func: Callable[[], T]) -> list[T]:
    """
    Pull a list of items.

    The list is a length-prefixed block (`capacity` length bytes). ``func``
    is called repeatedly until the block is fully consumed.
    """
    result: list[T] = []
    with pull_block(buf, capacity) as block_length:
        stop = buf.tell() + block_length
        while buf.tell() < stop:
            result.append(func())
    return result


def push_list(
    buf: Buffer, capacity: int, func: Callable[[T], None], values: Sequence[T]
) -> None:
    """
    Push a list of items.

    Every value is written with ``func`` inside a block whose length is
    encoded on `capacity` bytes.
    """
    with push_block(buf, capacity):
        for item in values:
            func(item)
|
|
|
|
|
|
def pull_opaque(buf: Buffer, capacity: int) -> bytes:
    """
    Pull an opaque value prefixed by a length of `capacity` bytes.
    """
    with pull_block(buf, capacity) as size:
        value = buf.pull_bytes(size)
    return value


def push_opaque(buf: Buffer, capacity: int, value: bytes) -> None:
    """
    Push an opaque value prefixed by a length of `capacity` bytes.
    """
    with push_block(buf, capacity):
        buf.push_bytes(value)
|
|
|
|
|
|
@contextmanager
def push_extension(buf: Buffer, extension_type: int) -> Generator:
    """
    Context manager writing one extension.

    Writes the uint16 type, then wraps the body in a block with a 2-byte
    length prefix.
    """
    buf.push_uint16(extension_type)
    block = push_block(buf, 2)
    with block:
        yield
|
|
|
|
|
|
# KeyShareEntry
|
|
|
|
|
|
# One key_share entry: (named group code point, key exchange payload).
KeyShareEntry = Tuple[int, bytes]


def pull_key_share(buf: Buffer) -> KeyShareEntry:
    """Read a uint16 group followed by a 2-byte length-prefixed payload."""
    return (buf.pull_uint16(), pull_opaque(buf, 2))


def push_key_share(buf: Buffer, value: KeyShareEntry) -> None:
    """Write a uint16 group followed by a 2-byte length-prefixed payload."""
    group = value[0]
    key_exchange = value[1]
    buf.push_uint16(group)
    push_opaque(buf, 2, key_exchange)
|
|
|
|
|
|
# ALPN
|
|
|
|
|
|
def pull_alpn_protocol(buf: Buffer) -> str:
    """Read one ALPN protocol name (1-byte length prefix, ASCII)."""
    raw_name = pull_opaque(buf, 1)
    return raw_name.decode("ascii")


def push_alpn_protocol(buf: Buffer, protocol: str) -> None:
    """Write one ALPN protocol name (1-byte length prefix, ASCII)."""
    raw_name = protocol.encode("ascii")
    push_opaque(buf, 1, raw_name)
|
|
|
|
|
|
# PRE SHARED KEY
|
|
|
|
# One offered PSK identity: (identity bytes, obfuscated ticket age).
PskIdentity = Tuple[bytes, int]


def pull_psk_identity(buf: Buffer) -> PskIdentity:
    """Read a 2-byte length-prefixed identity followed by a uint32 ticket age."""
    return (pull_opaque(buf, 2), buf.pull_uint32())


def push_psk_identity(buf: Buffer, entry: PskIdentity) -> None:
    """Write a 2-byte length-prefixed identity followed by a uint32 ticket age."""
    identity = entry[0]
    obfuscated_ticket_age = entry[1]
    push_opaque(buf, 2, identity)
    buf.push_uint32(obfuscated_ticket_age)


def pull_psk_binder(buf: Buffer) -> bytes:
    """Read one PSK binder (1-byte length prefix)."""
    binder = pull_opaque(buf, 1)
    return binder


def push_psk_binder(buf: Buffer, binder: bytes) -> None:
    """Write one PSK binder (1-byte length prefix)."""
    push_opaque(buf, 1, binder)
|
|
|
|
|
|
# MESSAGES
|
|
|
|
# A raw, uninterpreted extension: (extension type, payload).
Extension = Tuple[int, bytes]


@dataclass
class OfferedPsks:
    """The pre_shared_key extension as sent in a ClientHello."""

    # parallel lists: binders[i] belongs to identities[i]
    identities: list[PskIdentity]
    binders: list[bytes]
|
|
|
|
|
|
@dataclass
class ClientHello:
    """A parsed or to-be-serialised ClientHello handshake message."""

    random: bytes  # 32 bytes on the wire
    legacy_session_id: bytes
    cipher_suites: list[int]
    legacy_compression_methods: list[int]

    # extensions (None means "not present")
    alpn_protocols: list[str] | None = None
    early_data: bool = False
    key_share: list[KeyShareEntry] | None = None
    pre_shared_key: OfferedPsks | None = None
    psk_key_exchange_modes: list[int] | None = None
    server_name: str | None = None
    signature_algorithms: list[int] | None = None
    supported_groups: list[int] | None = None
    supported_versions: list[int] | None = None

    # extensions not modelled above, kept as raw (type, payload) pairs
    other_extensions: list[Extension] = field(default_factory=list)
|
|
|
|
|
|
def pull_client_hello(buf: Buffer) -> ClientHello:
    """
    Parse a ClientHello handshake message from ``buf``.

    Recognised extensions populate the matching ``ClientHello`` fields.
    status_request and GREASE (type 0x0A0A) extensions are consumed and
    discarded. Any other extension is kept raw in ``other_extensions``.
    pre_shared_key must be the last extension; this is enforced with an
    assertion.
    """
    assert buf.pull_uint8() == HandshakeType.CLIENT_HELLO
    with pull_block(buf, 3):
        assert buf.pull_uint16() == TLS_VERSION_1_2

        hello = ClientHello(
            random=buf.pull_bytes(32),
            legacy_session_id=pull_opaque(buf, 1),
            cipher_suites=pull_list(buf, 2, buf.pull_uint16),
            legacy_compression_methods=pull_list(buf, 1, buf.pull_uint8),
        )

        # extensions
        after_psk = False

        def pull_extension() -> None:
            # pre_shared_key MUST be last
            nonlocal after_psk
            assert not after_psk

            extension_type = buf.pull_uint16()
            extension_length = buf.pull_uint16()
            if extension_type == ExtensionType.KEY_SHARE:
                hello.key_share = pull_list(buf, 2, partial(pull_key_share, buf))
            elif extension_type == ExtensionType.SUPPORTED_VERSIONS:
                hello.supported_versions = pull_list(buf, 1, buf.pull_uint16)
            elif extension_type == ExtensionType.SIGNATURE_ALGORITHMS:
                hello.signature_algorithms = pull_list(buf, 2, buf.pull_uint16)
            elif extension_type == ExtensionType.SUPPORTED_GROUPS:
                hello.supported_groups = pull_list(buf, 2, buf.pull_uint16)
            elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
                hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
            elif extension_type == ExtensionType.SERVER_NAME:
                with pull_block(buf, 2):
                    # only a single host_name (type 0) entry is supported
                    assert buf.pull_uint8() == 0
                    hello.server_name = pull_opaque(buf, 2).decode("ascii")
            elif extension_type == ExtensionType.ALPN:
                hello.alpn_protocols = pull_list(
                    buf, 2, partial(pull_alpn_protocol, buf)
                )
            elif extension_type == ExtensionType.EARLY_DATA:
                hello.early_data = True
            elif extension_type == ExtensionType.PRE_SHARED_KEY:
                hello.pre_shared_key = OfferedPsks(
                    identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
                    binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
                )
                after_psk = True
            elif extension_type == ExtensionType.STATUS_REQUEST:
                # we don't implement it for the server...
                buf.pull_bytes(extension_length)
            elif extension_type == ExtensionType.GREASE:
                # GREASE extensions may carry an arbitrary payload. It must be
                # skipped, otherwise every following extension is misparsed.
                buf.pull_bytes(extension_length)
            else:
                hello.other_extensions.append(
                    (extension_type, buf.pull_bytes(extension_length))
                )

        pull_list(buf, 2, pull_extension)

    return hello
|
|
|
|
|
|
def push_client_hello(buf: Buffer, hello: ClientHello) -> None:
    """
    Serialise ``hello`` as a ClientHello handshake message into ``buf``.

    Extensions are written in this order:

    - an empty GREASE extension;
    - key_share, supported_versions, signature_algorithms, supported_groups;
    - optional psk_key_exchange_modes, server_name and ALPN;
    - raw ``other_extensions``;
    - early_data, when requested;
    - an OCSP status_request;
    - pre_shared_key, which must be last.
    """
    buf.push_uint8(HandshakeType.CLIENT_HELLO)
    with push_block(buf, 3):
        buf.push_uint16(TLS_VERSION_1_2)
        buf.push_bytes(hello.random)
        push_opaque(buf, 1, hello.legacy_session_id)
        push_list(buf, 2, buf.push_uint16, hello.cipher_suites)
        push_list(buf, 1, buf.push_uint8, hello.legacy_compression_methods)

        # extensions
        with push_block(buf, 2):
            with push_extension(buf, ExtensionType.GREASE):
                pass

            # list-valued extensions that are always sent, in wire order
            always_sent = (
                (ExtensionType.KEY_SHARE, 2, partial(push_key_share, buf), hello.key_share),
                (ExtensionType.SUPPORTED_VERSIONS, 1, buf.push_uint16, hello.supported_versions),
                (ExtensionType.SIGNATURE_ALGORITHMS, 2, buf.push_uint16, hello.signature_algorithms),
                (ExtensionType.SUPPORTED_GROUPS, 2, buf.push_uint16, hello.supported_groups),
            )
            for ext_type, capacity, push_item, values in always_sent:
                with push_extension(buf, ext_type):
                    push_list(buf, capacity, push_item, values)

            if hello.psk_key_exchange_modes is not None:
                with push_extension(buf, ExtensionType.PSK_KEY_EXCHANGE_MODES):
                    push_list(buf, 1, buf.push_uint8, hello.psk_key_exchange_modes)

            if hello.server_name is not None:
                with push_extension(buf, ExtensionType.SERVER_NAME):
                    with push_block(buf, 2):
                        buf.push_uint8(0)  # name type: host_name
                        push_opaque(buf, 2, hello.server_name.encode("ascii"))

            if hello.alpn_protocols is not None:
                with push_extension(buf, ExtensionType.ALPN):
                    push_list(
                        buf, 2, partial(push_alpn_protocol, buf), hello.alpn_protocols
                    )

            for ext_type, ext_value in hello.other_extensions:
                with push_extension(buf, ext_type):
                    buf.push_bytes(ext_value)

            if hello.early_data:
                with push_extension(buf, ExtensionType.EARLY_DATA):
                    pass

            with push_extension(buf, ExtensionType.STATUS_REQUEST):
                buf.push_uint8(1)  # OCSP
                with push_block(buf, 2):  # empty responder_id_list
                    pass
                with push_block(buf, 2):  # empty extensions
                    pass

            # pre_shared_key MUST be last
            offered = hello.pre_shared_key
            if offered is not None:
                with push_extension(buf, ExtensionType.PRE_SHARED_KEY):
                    push_list(buf, 2, partial(push_psk_identity, buf), offered.identities)
                    push_list(buf, 2, partial(push_psk_binder, buf), offered.binders)
|
|
|
|
|
|
@dataclass
class ServerHello:
    """A parsed or to-be-serialised ServerHello handshake message."""

    random: bytes  # 32 bytes on the wire
    legacy_session_id: bytes
    cipher_suite: int
    compression_method: int

    # extensions (None means "not present")
    key_share: KeyShareEntry | None = None
    pre_shared_key: int | None = None  # selected identity index
    supported_version: int | None = None

    # other extensions, kept raw as (type, payload)
    other_extensions: list[tuple[int, bytes]] = field(default_factory=list)
|
|
|
|
|
|
def pull_server_hello(buf: Buffer) -> ServerHello:
    """
    Parse a ServerHello handshake message from ``buf``.

    supported_versions, key_share and pre_shared_key fill the matching
    fields. Any other extension is kept raw in ``other_extensions``.
    """
    assert buf.pull_uint8() == HandshakeType.SERVER_HELLO
    with pull_block(buf, 3):
        assert buf.pull_uint16() == TLS_VERSION_1_2

        hello = ServerHello(
            random=buf.pull_bytes(32),
            legacy_session_id=pull_opaque(buf, 1),
            cipher_suite=buf.pull_uint16(),
            compression_method=buf.pull_uint8(),
        )

        def read_extension() -> None:
            ext_type = buf.pull_uint16()
            ext_length = buf.pull_uint16()

            if ext_type == ExtensionType.SUPPORTED_VERSIONS:
                hello.supported_version = buf.pull_uint16()
            elif ext_type == ExtensionType.KEY_SHARE:
                hello.key_share = pull_key_share(buf)
            elif ext_type == ExtensionType.PRE_SHARED_KEY:
                hello.pre_shared_key = buf.pull_uint16()
            else:
                payload = buf.pull_bytes(ext_length)
                hello.other_extensions.append((ext_type, payload))

        pull_list(buf, 2, read_extension)

    return hello
|
|
|
|
|
|
def push_server_hello(buf: Buffer, hello: ServerHello) -> None:
    """
    Serialise ``hello`` as a ServerHello handshake message into ``buf``.

    Extensions are written in this order: supported_versions, key_share and
    pre_shared_key (each only when set), then ``other_extensions``.
    """
    buf.push_uint8(HandshakeType.SERVER_HELLO)
    with push_block(buf, 3):
        buf.push_uint16(TLS_VERSION_1_2)
        buf.push_bytes(hello.random)

        push_opaque(buf, 1, hello.legacy_session_id)
        buf.push_uint16(hello.cipher_suite)
        buf.push_uint8(hello.compression_method)

        # extensions
        with push_block(buf, 2):
            version = hello.supported_version
            if version is not None:
                with push_extension(buf, ExtensionType.SUPPORTED_VERSIONS):
                    buf.push_uint16(version)

            share = hello.key_share
            if share is not None:
                with push_extension(buf, ExtensionType.KEY_SHARE):
                    push_key_share(buf, share)

            psk_index = hello.pre_shared_key
            if psk_index is not None:
                with push_extension(buf, ExtensionType.PRE_SHARED_KEY):
                    buf.push_uint16(psk_index)

            for ext_type, ext_value in hello.other_extensions:
                with push_extension(buf, ext_type):
                    buf.push_bytes(ext_value)
|
|
|
|
|
|
@dataclass
class NewSessionTicket:
    """A NewSessionTicket post-handshake message."""

    ticket_lifetime: int = 0
    ticket_age_add: int = 0
    ticket_nonce: bytes = b""
    ticket: bytes = b""

    # extensions
    max_early_data_size: int | None = None  # from the early_data extension
    other_extensions: list[tuple[int, bytes]] = field(default_factory=list)
|
|
|
|
|
|
def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket:
    """
    Parse a NewSessionTicket message from ``buf``.

    The early_data extension sets ``max_early_data_size`` (uint32). Other
    extensions are kept raw in ``other_extensions``.
    """
    assert buf.pull_uint8() == HandshakeType.NEW_SESSION_TICKET
    with pull_block(buf, 3):
        ticket = NewSessionTicket(
            ticket_lifetime=buf.pull_uint32(),
            ticket_age_add=buf.pull_uint32(),
            ticket_nonce=pull_opaque(buf, 1),
            ticket=pull_opaque(buf, 2),
        )

        def read_extension() -> None:
            ext_type = buf.pull_uint16()
            ext_length = buf.pull_uint16()
            if ext_type == ExtensionType.EARLY_DATA:
                ticket.max_early_data_size = buf.pull_uint32()
            else:
                ticket.other_extensions.append((ext_type, buf.pull_bytes(ext_length)))

        pull_list(buf, 2, read_extension)

    return ticket
|
|
|
|
|
|
def push_new_session_ticket(buf: Buffer, new_session_ticket: NewSessionTicket) -> None:
    """
    Serialise ``new_session_ticket`` into ``buf``.

    early_data (carrying ``max_early_data_size``) is written when set,
    followed by ``other_extensions``.
    """
    ticket = new_session_ticket
    buf.push_uint8(HandshakeType.NEW_SESSION_TICKET)
    with push_block(buf, 3):
        buf.push_uint32(ticket.ticket_lifetime)
        buf.push_uint32(ticket.ticket_age_add)
        push_opaque(buf, 1, ticket.ticket_nonce)
        push_opaque(buf, 2, ticket.ticket)

        with push_block(buf, 2):
            if ticket.max_early_data_size is not None:
                with push_extension(buf, ExtensionType.EARLY_DATA):
                    buf.push_uint32(ticket.max_early_data_size)

            for ext_type, ext_value in ticket.other_extensions:
                with push_extension(buf, ext_type):
                    buf.push_bytes(ext_value)
|
|
|
|
|
|
@dataclass
class EncryptedExtensions:
    """An EncryptedExtensions handshake message."""

    alpn_protocol: str | None = None  # the single negotiated protocol
    early_data: bool = False

    # other extensions, kept raw as (type, payload)
    other_extensions: list[tuple[int, bytes]] = field(default_factory=list)
|
|
|
|
|
|
def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions:
    """
    Parse an EncryptedExtensions message from ``buf``.

    For ALPN only the first protocol of the list is kept. early_data sets
    the flag. Other extensions are kept raw in ``other_extensions``.
    """
    result = EncryptedExtensions()

    assert buf.pull_uint8() == HandshakeType.ENCRYPTED_EXTENSIONS
    with pull_block(buf, 3):

        def read_extension() -> None:
            ext_type = buf.pull_uint16()
            ext_length = buf.pull_uint16()
            if ext_type == ExtensionType.ALPN:
                protocols = pull_list(buf, 2, partial(pull_alpn_protocol, buf))
                result.alpn_protocol = protocols[0]
            elif ext_type == ExtensionType.EARLY_DATA:
                result.early_data = True
            else:
                result.other_extensions.append((ext_type, buf.pull_bytes(ext_length)))

        pull_list(buf, 2, read_extension)

    return result
|
|
|
|
|
|
def push_encrypted_extensions(buf: Buffer, extensions: EncryptedExtensions) -> None:
    """
    Serialise ``extensions`` as an EncryptedExtensions message into ``buf``.

    Order: ALPN (a one-element list, when set), early_data (when true), then
    ``other_extensions``.
    """
    buf.push_uint8(HandshakeType.ENCRYPTED_EXTENSIONS)
    with push_block(buf, 3), push_block(buf, 2):
        protocol = extensions.alpn_protocol
        if protocol is not None:
            with push_extension(buf, ExtensionType.ALPN):
                push_list(buf, 2, partial(push_alpn_protocol, buf), [protocol])

        if extensions.early_data:
            with push_extension(buf, ExtensionType.EARLY_DATA):
                pass

        for ext_type, ext_value in extensions.other_extensions:
            with push_extension(buf, ext_type):
                buf.push_bytes(ext_value)
|
|
|
|
|
|
# One certificate_list entry: (certificate data, raw extensions block).
CertificateEntry = Tuple[bytes, bytes]


@dataclass
class Certificate:
    """A Certificate handshake message."""

    request_context: bytes = b""
    certificates: list[CertificateEntry] = field(default_factory=list)


@dataclass
class CertificateRequest:
    """A CertificateRequest handshake message."""

    request_context: bytes = b""
    signature_algorithms: list[int] | None = None
    # other extensions, kept raw as (type, payload)
    other_extensions: list[tuple[int, bytes]] = field(default_factory=list)
|
|
|
|
|
|
def pull_certificate(buf: Buffer) -> Certificate:
    """
    Parse a Certificate handshake message from ``buf``.

    Each list entry is a 3-byte length-prefixed certificate followed by a
    2-byte length-prefixed extensions blob; both are kept raw.
    """
    certificate = Certificate()

    assert buf.pull_uint8() == HandshakeType.CERTIFICATE
    with pull_block(buf, 3):
        certificate.request_context = pull_opaque(buf, 1)
        certificate.certificates = pull_list(
            buf, 3, lambda: (pull_opaque(buf, 3), pull_opaque(buf, 2))
        )

    return certificate
|
|
|
|
|
|
def pull_certificate_request(buf: Buffer) -> CertificateRequest:
    """
    Parse a CertificateRequest handshake message from ``buf``.

    signature_algorithms fills the matching field. Other extensions are kept
    raw in ``other_extensions``.
    """
    request = CertificateRequest()

    assert buf.pull_uint8() == HandshakeType.CERTIFICATE_REQUEST
    with pull_block(buf, 3):
        request.request_context = pull_opaque(buf, 1)

        def read_extension() -> None:
            ext_type = buf.pull_uint16()
            ext_length = buf.pull_uint16()
            if ext_type != ExtensionType.SIGNATURE_ALGORITHMS:
                request.other_extensions.append((ext_type, buf.pull_bytes(ext_length)))
                return
            request.signature_algorithms = pull_list(buf, 2, buf.pull_uint16)

        pull_list(buf, 2, read_extension)

    return request
|
|
|
|
|
|
def push_certificate(buf: Buffer, certificate: Certificate) -> None:
    """
    Serialise ``certificate`` as a Certificate handshake message into ``buf``.

    Each entry becomes a 3-byte length-prefixed certificate followed by a
    2-byte length-prefixed extensions blob.
    """
    buf.push_uint8(HandshakeType.CERTIFICATE)
    with push_block(buf, 3):
        push_opaque(buf, 1, certificate.request_context)

        def write_entry(entry: CertificateEntry) -> None:
            push_opaque(buf, 3, entry[0])
            push_opaque(buf, 2, entry[1])

        push_list(buf, 3, write_entry, certificate.certificates)
|
|
|
|
|
|
@dataclass
class CertificateVerify:
    """A CertificateVerify handshake message."""

    algorithm: int  # signature scheme code point
    signature: bytes


def pull_certificate_verify(buf: Buffer) -> CertificateVerify:
    """Parse a CertificateVerify message: uint16 algorithm + 2-byte-prefixed signature."""
    assert buf.pull_uint8() == HandshakeType.CERTIFICATE_VERIFY
    with pull_block(buf, 3):
        verify = CertificateVerify(
            algorithm=buf.pull_uint16(), signature=pull_opaque(buf, 2)
        )
    return verify


def push_certificate_verify(buf: Buffer, verify: CertificateVerify) -> None:
    """Serialise a CertificateVerify message into ``buf``."""
    buf.push_uint8(HandshakeType.CERTIFICATE_VERIFY)
    with push_block(buf, 3):
        buf.push_uint16(verify.algorithm)
        push_opaque(buf, 2, verify.signature)
|
|
|
|
|
|
@dataclass
|
|
class Finished:
|
|
verify_data: bytes = b""
|
|
|
|
|
|
def pull_finished(buf: Buffer) -> Finished:
    """
    Parse a Finished handshake message (RFC 8446 §4.4.4).

    :param buf: buffer positioned at the message's handshake type byte.
    :return: the decoded message with its ``verify_data``.
    """
    finished = Finished()

    # Read the type byte outside of the assert: ``python -O`` strips assert
    # statements, which would skip this read and misalign the buffer.
    message_type = buf.pull_uint8()
    assert message_type == HandshakeType.FINISHED

    finished.verify_data = pull_opaque(buf, 3)

    return finished
|
|
|
|
|
|
def push_finished(buf: Buffer, finished: Finished) -> None:
    """Serialize a Finished message: type byte, then 3-byte length-prefixed verify_data."""
    verify_data = finished.verify_data
    buf.push_uint8(HandshakeType.FINISHED)
    push_opaque(buf, 3, verify_data)
|
|
|
|
|
|
class KeySchedule:
    """
    TLS 1.3 key schedule bound to a single cipher suite (RFC 8446 §7.1).

    Keeps the running transcript hash and the current secret of the
    HKDF-Extract chain; ``generation`` counts how many extractions have been
    performed so far.
    """

    def __init__(self, cipher_suite: CipherSuite):
        bits = cipher_suite_hash(cipher_suite)
        self.algorithm = bits
        self.digest_size = bits // 8
        self.cipher_suite = cipher_suite
        self.generation = 0
        self.hash = hashlib.new(f"sha{bits}")
        # digest of the empty transcript, used by the "derived" step
        self.hash_empty_value = self.hash.copy().digest()
        self.secret = bytes(self.digest_size)

    def certificate_verify_data(self, context_string: bytes) -> bytes:
        """Return the content a CertificateVerify signature covers (RFC 8446 §4.4.3)."""
        transcript = self.hash.copy().digest()
        return b" " * 64 + context_string + b"\x00" + transcript

    def finished_verify_data(self, secret: bytes) -> bytes:
        """Return the Finished verify_data: HMAC of the transcript keyed by finished_key."""
        finished_key = hkdf_expand_label(
            algorithm=self.algorithm,
            secret=secret,
            label=b"finished",
            hash_value=b"",
            length=self.digest_size,
        )
        mac = HMAC(
            finished_key,
            msg=self.hash.copy().digest(),
            digestmod=f"sha{self.algorithm}",
        )
        return mac.digest()

    def derive_secret(self, label: bytes) -> bytes:
        """Derive-Secret(current secret, label, transcript so far)."""
        transcript = self.hash.copy().digest()
        return hkdf_expand_label(
            algorithm=self.algorithm,
            secret=self.secret,
            label=label,
            hash_value=transcript,
            length=self.digest_size,
        )

    def extract(self, key_material: bytes | None = None) -> None:
        """
        Advance the schedule by one HKDF-Extract step.

        After the first step the current secret is first run through the
        "derived" expansion. ``None`` key material means an all-zero input
        of digest size.
        """
        ikm = bytes(self.digest_size) if key_material is None else key_material

        if self.generation:
            self.secret = hkdf_expand_label(
                algorithm=self.algorithm,
                secret=self.secret,
                label=b"derived",
                hash_value=self.hash_empty_value,
                length=self.digest_size,
            )

        self.generation += 1
        self.secret = hkdf_extract(
            algorithm=self.algorithm, salt=self.secret, key_material=ikm
        )

    def update_hash(self, data: bytes) -> None:
        """Append ``data`` to the transcript hash."""
        self.hash.update(data)
|
|
|
|
|
|
class KeyScheduleProxy:
    """
    Drive one KeySchedule per cipher suite in lock-step.

    Used before the server has chosen a cipher suite: each extraction and
    transcript update is applied to every schedule, so ``select`` can later
    return the one matching the negotiated suite.
    """

    def __init__(self, cipher_suites: list[CipherSuite]):
        self.__schedules = {suite: KeySchedule(suite) for suite in cipher_suites}

    def extract(self, key_material: bytes | None = None) -> None:
        """Forward ``extract`` to every schedule."""
        for schedule in self.__schedules.values():
            schedule.extract(key_material)

    def update_hash(self, data: bytes) -> None:
        """Feed ``data`` into every schedule's transcript hash."""
        for schedule in self.__schedules.values():
            schedule.update_hash(data)

    def select(self, cipher_suite: CipherSuite) -> KeySchedule:
        """Return the schedule for ``cipher_suite`` (KeyError if absent)."""
        return self.__schedules[cipher_suite]
|
|
|
|
|
|
# Maps each supported TLS 1.3 cipher suite to the output size, in bits, of
# the hash it uses. cipher_suite_hash() returns this value and KeySchedule
# builds its hash name as ``sha{bits}`` and its digest size as bits / 8.
CIPHER_SUITES = {
    CipherSuite.AES_128_GCM_SHA256: 256,
    CipherSuite.AES_256_GCM_SHA384: 384,
    CipherSuite.CHACHA20_POLY1305_SHA256: 256,
}
|
|
|
|
# Signing parameters for the non-EdDSA signature schemes, as
# (padding flag, hash size in bits). The flag is None for ECDSA, False for
# RSA PKCS#1 v1.5 and True for RSA-PSS. signature_algorithm_params() turns
# an entry into the extra arguments for a private key's ``sign`` call
# (ECDSA entries contribute no extra arguments).
SIGNATURE_ALGORITHMS: dict[SignatureAlgorithm, tuple[bool | None, int]] = {
    SignatureAlgorithm.ECDSA_SECP256R1_SHA256: (None, 256),
    SignatureAlgorithm.ECDSA_SECP384R1_SHA384: (None, 384),
    SignatureAlgorithm.ECDSA_SECP521R1_SHA512: (None, 512),
    SignatureAlgorithm.RSA_PKCS1_SHA256: (False, 256),
    SignatureAlgorithm.RSA_PKCS1_SHA384: (False, 384),
    SignatureAlgorithm.RSA_PKCS1_SHA512: (False, 512),
    SignatureAlgorithm.RSA_PSS_RSAE_SHA256: (True, 256),
    SignatureAlgorithm.RSA_PSS_RSAE_SHA384: (True, 384),
    SignatureAlgorithm.RSA_PSS_RSAE_SHA512: (True, 512),
}
|
|
|
|
|
|
def cipher_suite_hash(cipher_suite: CipherSuite) -> int:
    """
    Return the hash output size in bits used by ``cipher_suite``.

    Raises KeyError for a suite missing from CIPHER_SUITES (e.g. GREASE).
    """
    hash_bits = CIPHER_SUITES[cipher_suite]
    return hash_bits
|
|
|
|
|
|
def negotiate(
    supported: list[T],
    offered: list[Any] | None,
    exc: Alert | None = None,
    excl: T | None = None,
) -> T | None:
    """
    Pick the first entry of ``supported`` that the peer also ``offered``.

    Local preference wins: ``supported`` is scanned in order and the first
    value present in ``offered`` is returned.

    :param supported: locally supported values, most preferred first.
    :param offered: values offered by the peer, or ``None`` if the peer sent
        nothing for this parameter.
    :param exc: exception raised when no common value exists; when ``None``
        the function returns ``None`` instead.
    :param excl: a value never to be selected even if both sides list it
        (used to skip GREASE placeholders).
    :return: the selected value, or ``None`` when nothing matches and
        ``exc`` is ``None``.
    """
    if offered is not None:
        for candidate in supported:
            if excl is not None and candidate == excl:
                continue
            if candidate in offered:
                return candidate

    if exc is not None:
        raise exc

    return None
|
|
|
|
|
|
def signature_algorithm_params(signature_algorithm: int) -> tuple[Any, ...]:
    """
    Return the extra positional arguments for a private key's ``sign``.

    EdDSA and ECDSA schemes need none (empty tuple); RSA schemes yield
    ``(is_pss, hash_size)`` from SIGNATURE_ALGORITHMS.
    """
    if signature_algorithm not in (SignatureAlgorithm.ED25519, SignatureAlgorithm.ED448):
        is_pss, hash_size = SIGNATURE_ALGORITHMS[SignatureAlgorithm(signature_algorithm)]
        if is_pss is not None:
            return (is_pss, hash_size)
    return ()
|
|
|
|
|
|
@contextmanager
def push_message(
    key_schedule: KeySchedule | KeyScheduleProxy, buf: Buffer
) -> Generator:
    """
    Add every byte written to ``buf`` inside the ``with`` body to the transcript.

    The hash is only updated when the body completes without raising.
    """
    start = buf.tell()
    yield
    end = buf.tell()
    key_schedule.update_hash(buf.data_slice(start, end))
|
|
|
|
|
|
# callback types
|
|
|
|
|
|
@dataclass
class SessionTicket:
    """
    A TLS session ticket for session resumption.

    ``resumption_secret`` is the PSK derived for this ticket. The validity
    window uses naive UTC datetimes, as returned by ``utcnow()``.
    """

    # required fields; their order defines the positional constructor
    age_add: int
    cipher_suite: CipherSuite
    not_valid_after: datetime.datetime
    not_valid_before: datetime.datetime
    resumption_secret: bytes
    server_name: str
    ticket: bytes

    # optional fields
    max_early_data_size: int | None = None
    other_extensions: list[tuple[int, bytes]] = field(default_factory=list)

    @property
    def is_valid(self) -> bool:
        """True while the current time lies inside the validity window (inclusive)."""
        current = utcnow()
        return self.not_valid_before <= current <= self.not_valid_after

    @property
    def obfuscated_age(self) -> int:
        """Milliseconds since issue plus ``age_add``, modulo 2**32 (RFC 8446 §4.2.11.1)."""
        elapsed = utcnow() - self.not_valid_before
        age_ms = int(elapsed.total_seconds() * 1000)
        return (age_ms + self.age_add) & 0xFFFFFFFF
|
|
|
|
|
|
# Invoked by the client with the ALPN protocol taken from EncryptedExtensions.
# NOTE(review): the value passed may be None when no protocol was selected,
# although the alias says ``str`` — confirm.
AlpnHandler = Callable[[str], None]
# Given ticket bytes, returns the matching SessionTicket or None; presumably
# used on the server side for resumption (its caller is outside this view).
SessionTicketFetcher = Callable[[bytes], Optional[SessionTicket]]
# Receives each SessionTicket built from a NewSessionTicket message.
SessionTicketHandler = Callable[[SessionTicket], None]
|
|
|
|
|
|
class Context:
|
|
    def __init__(
        self,
        is_client: bool,
        alpn_protocols: list[str] | None = None,
        cadata: bytes | None = None,
        cafile: str | None = None,
        capath: str | None = None,
        cipher_suites: list[CipherSuite] | None = None,
        logger: logging.Logger | logging.LoggerAdapter | None = None,
        max_early_data: int | None = None,
        server_name: str | None = None,
        verify_mode: int | None = None,
        hostname_checks_common_name: bool = False,
        assert_fingerprint: str | None = None,
        verify_hostname: bool = True,
    ):
        """
        Create a TLS 1.3 handshake context for one endpoint.

        :param is_client: True for the client role. It selects the initial
            state, whether a client random is generated, and the default
            ``verify_mode``.
        :param alpn_protocols: ALPN protocols to offer (client) or negotiate
            against (server).
        :param cadata: CA certificate data passed to ``verify_certificate``.
        :param cafile: CA file path passed to ``verify_certificate``.
        :param capath: CA directory path passed to ``verify_certificate``.
        :param cipher_suites: cipher suites in preference order. Defaults to
            a GREASE placeholder, then AES-128-GCM, ChaCha20-Poly1305 and
            AES-256-GCM.
        :param logger: optional logger used for debug messages.
        :param max_early_data: stored as ``_max_early_data``. Presumably the
            server's 0-RTT limit; its consumer is outside this view (confirm).
        :param server_name: peer host name. Non-ASCII names are converted
            with ``idna_encode``.
        :param verify_mode: an ``ssl.CERT_*`` constant. Defaults to
            ``ssl.CERT_REQUIRED`` for clients and ``ssl.CERT_NONE`` for
            servers.
        :param hostname_checks_common_name: stored only. It is not read by
            the code visible here (TODO confirm where it is consumed).
        :param assert_fingerprint: optional hex fingerprint (colons allowed)
            that the server certificate must match.
        :param verify_hostname: whether the server name is asserted during
            certificate verification.
        """
        # configuration
        self._alpn_protocols = alpn_protocols
        self._cadata = cadata
        self._cafile = cafile
        self._capath = capath
        self._hostname_checks_common_name = hostname_checks_common_name
        self._assert_fingerprint = assert_fingerprint
        self._verify_hostname = verify_hostname
        # local certificate material (used e.g. to answer a CertificateRequest)
        self.certificate: X509Certificate | None = None
        self.certificate_chain: list[X509Certificate] = []
        self.certificate_private_key: (
            EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None
        ) = None
        # extra extensions appended to the ClientHello
        self.handshake_extensions: list[Extension] = []
        self._max_early_data = max_early_data
        # ticket offered for resumption when still valid
        self.session_ticket: SessionTicket | None = None

        # ensure pure ascii server name
        if server_name is not None and not server_name.isascii():
            server_name = idna_encode(server_name).decode()

        self._server_name = server_name

        if verify_mode is not None:
            self._verify_mode = verify_mode
        else:
            self._verify_mode = ssl.CERT_REQUIRED if is_client else ssl.CERT_NONE

        # callbacks
        self.alpn_cb: AlpnHandler | None = None
        self.get_session_ticket_cb: SessionTicketFetcher | None = None
        self.new_session_ticket_cb: SessionTicketHandler | None = None
        # called whenever a traffic secret is installed; no-op by default
        self.update_traffic_key_cb: Callable[
            [Direction, Epoch, CipherSuite, bytes], None
        ] = lambda d, e, c, s: None

        # supported parameters
        # GREASE entries (RFC 8701) are placeholders that keep peers tolerant
        # of unknown values; they are excluded when negotiating.
        if cipher_suites is not None:
            self._cipher_suites = cipher_suites
        else:
            self._cipher_suites = [
                CipherSuite.GREASE,
                CipherSuite.AES_128_GCM_SHA256,
                CipherSuite.CHACHA20_POLY1305_SHA256,
                CipherSuite.AES_256_GCM_SHA384,
            ]
        self._legacy_compression_methods: list[int] = [CompressionMethod.NULL]
        # only PSK with (EC)DHE is offered, never PSK-only key exchange
        self._psk_key_exchange_modes: list[int] = [PskKeyExchangeMode.PSK_DHE_KE]
        self._signature_algorithms: list[int] = [
            SignatureAlgorithm.ECDSA_SECP256R1_SHA256,
            SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
            SignatureAlgorithm.RSA_PKCS1_SHA256,
            SignatureAlgorithm.ECDSA_SECP384R1_SHA384,
            SignatureAlgorithm.RSA_PSS_RSAE_SHA384,
            SignatureAlgorithm.RSA_PKCS1_SHA384,
            SignatureAlgorithm.RSA_PSS_RSAE_SHA512,
            SignatureAlgorithm.RSA_PKCS1_SHA512,
            SignatureAlgorithm.ED25519,
        ]

        # the client generates a key share for every group listed here
        self._supported_groups = [
            Group.GREASE,
            Group.X25519ML768,
            Group.X25519,
            Group.SECP256R1,
            Group.SECP384R1,
        ]

        self._supported_versions = [TLS_VERSION_GREASE, TLS_VERSION_1_3]

        # state
        self.alpn_negotiated: str | None = None
        self.early_data_accepted: bool = False
        self.key_schedule: KeySchedule | None = None
        self.received_extensions: list[Extension] | None = None
        self._key_schedule_psk: KeySchedule | None = None
        self._key_schedule_proxy: KeyScheduleProxy | None = None
        self._new_session_ticket: NewSessionTicket | None = None
        self._peer_certificate: X509Certificate | None = None
        self._peer_certificate_chain: list[X509Certificate] = []
        self._ocsp_response: bytes | None = None
        # bytes of a partially received handshake message
        self._receive_buffer = b""
        self._session_resumed = False
        self._enc_key: bytes | None = None
        self._dec_key: bytes | None = None
        self._certificate_request: CertificateRequest | None = None
        self.__logger = logger

        # KeyExchange
        # private keys generated for the ClientHello key shares
        self._ec_p256_private_key: ECDHP256KeyExchange | None = None
        self._ec_p384_private_key: ECDHP384KeyExchange | None = None
        self._ec_p521_private_key: ECDHP521KeyExchange | None = None
        self._x25519_private_key: X25519KeyExchange | None = None
        # despite its name, this holds the X25519ML768 hybrid exchange
        self._x25519_kyber_768_private_key: X25519ML768KeyExchange | None = None

        if is_client:
            self.client_random = os.urandom(32)
            self.legacy_session_id = b""
            self.state = State.CLIENT_HANDSHAKE_START
        else:
            self.client_random = None
            self.legacy_session_id = None
            self.state = State.SERVER_EXPECT_CLIENT_HELLO
|
|
|
|
@property
|
|
def peer_certificate(self) -> X509Certificate | None:
|
|
return self._peer_certificate
|
|
|
|
@property
|
|
def peer_certificate_chain(self) -> list[X509Certificate]:
|
|
return self._peer_certificate_chain
|
|
|
|
@property
|
|
def session_resumed(self) -> bool:
|
|
"""
|
|
Returns True if session resumption was successfully used.
|
|
"""
|
|
return self._session_resumed
|
|
|
|
    def handle_message(
        self, input_data: bytes, output_buf: dict[Epoch, Buffer]
    ) -> None:
        """
        Feed handshake bytes into the state machine and emit any responses.

        For a client that has not started yet, the call ignores
        ``input_data`` and only writes the ClientHello into the INITIAL
        buffer. Otherwise ``input_data`` is appended to an internal buffer.
        Each complete handshake message (4-byte header: a 1-byte type and a
        24-bit big-endian length) is dispatched to the handler for the
        current state. An incomplete trailing message is kept for the next
        call.

        :param input_data: raw handshake bytes received from the peer.
        :param output_buf: per-epoch buffers that handlers write replies into.
        :raises AlertUnexpectedMessage: the message type is not valid in the
            current state, or any message arrives once the server has
            finished its handshake.
        """
        if self.state == State.CLIENT_HANDSHAKE_START:
            self._client_send_hello(output_buf[Epoch.INITIAL])
            return

        self._receive_buffer += input_data
        while len(self._receive_buffer) >= 4:
            # determine message length
            message_type = self._receive_buffer[0]
            message_length = 4 + int.from_bytes(
                self._receive_buffer[1:4], byteorder="big"
            )

            # check message is complete
            if len(self._receive_buffer) < message_length:
                break
            message = self._receive_buffer[:message_length]
            self._receive_buffer = self._receive_buffer[message_length:]

            # handlers receive the whole message, header included
            input_buf = Buffer(data=message)

            # client states

            if self.state == State.CLIENT_EXPECT_SERVER_HELLO:
                if message_type == HandshakeType.SERVER_HELLO:
                    self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL])
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS:
                if message_type == HandshakeType.ENCRYPTED_EXTENSIONS:
                    self._client_handle_encrypted_extensions(input_buf)
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE:
                if message_type == HandshakeType.CERTIFICATE:
                    self._client_handle_certificate(input_buf)
                elif message_type == HandshakeType.CERTIFICATE_REQUEST:
                    self._client_handle_certificate_request(input_buf)
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY:
                if message_type == HandshakeType.CERTIFICATE_VERIFY:
                    self._client_handle_certificate_verify(input_buf)
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.CLIENT_EXPECT_FINISHED:
                if message_type == HandshakeType.FINISHED:
                    # the client's own Certificate/Finished go out in HANDSHAKE
                    self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE])
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.CLIENT_POST_HANDSHAKE:
                if message_type == HandshakeType.NEW_SESSION_TICKET:
                    self._client_handle_new_session_ticket(input_buf)
                else:
                    raise AlertUnexpectedMessage

            # server states

            elif self.state == State.SERVER_EXPECT_CLIENT_HELLO:
                if message_type == HandshakeType.CLIENT_HELLO:
                    self._server_handle_hello(
                        input_buf,
                        output_buf[Epoch.INITIAL],
                        output_buf[Epoch.HANDSHAKE],
                        output_buf[Epoch.ONE_RTT],
                    )
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.SERVER_EXPECT_FINISHED:
                if message_type == HandshakeType.FINISHED:
                    self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT])
                else:
                    raise AlertUnexpectedMessage
            elif self.state == State.SERVER_POST_HANDSHAKE:
                raise AlertUnexpectedMessage

            # the handler must have consumed the message exactly; leftover
            # bytes mean the parser and the framed length disagree
            assert input_buf.eof()
|
|
|
|
def _build_session_ticket(
|
|
self, new_session_ticket: NewSessionTicket, other_extensions: list[Extension]
|
|
) -> SessionTicket:
|
|
resumption_master_secret = self.key_schedule.derive_secret(b"res master")
|
|
resumption_secret = hkdf_expand_label(
|
|
algorithm=self.key_schedule.algorithm,
|
|
secret=resumption_master_secret,
|
|
label=b"resumption",
|
|
hash_value=new_session_ticket.ticket_nonce,
|
|
length=self.key_schedule.digest_size,
|
|
)
|
|
|
|
timestamp = utcnow()
|
|
return SessionTicket(
|
|
age_add=new_session_ticket.ticket_age_add,
|
|
cipher_suite=self.key_schedule.cipher_suite,
|
|
max_early_data_size=new_session_ticket.max_early_data_size,
|
|
not_valid_after=timestamp
|
|
+ datetime.timedelta(seconds=new_session_ticket.ticket_lifetime),
|
|
not_valid_before=timestamp,
|
|
other_extensions=other_extensions,
|
|
resumption_secret=resumption_secret,
|
|
server_name=self._server_name,
|
|
ticket=new_session_ticket.ticket,
|
|
)
|
|
|
|
    def _client_send_hello(self, output_buf: Buffer) -> None:
        """
        Build and write the ClientHello, preparing all client-side key state.

        The method:
        - generates a key share for every configured group (plus a dummy
          GREASE share);
        - offers a PSK when a still-valid session ticket is available,
          computing its binder and, if the ticket allows early data, the
          0-RTT key;
        - creates a key schedule for every non-GREASE cipher suite, so the
          one the server picks can be selected later;
        - writes the ClientHello into ``output_buf``.

        :param output_buf: INITIAL-epoch buffer receiving the ClientHello.
        """
        key_share: list[KeyShareEntry] = []
        supported_groups: list[int] = []

        # groups without a branch below are silently left out of both lists
        for group in self._supported_groups:
            if group == Group.SECP256R1:
                self._ec_p256_private_key = ECDHP256KeyExchange()
                key_share.append(
                    (Group.SECP256R1, self._ec_p256_private_key.public_key())
                )
                supported_groups.append(Group.SECP256R1)
            elif group == Group.SECP384R1:
                self._ec_p384_private_key = ECDHP384KeyExchange()
                key_share.append(
                    (Group.SECP384R1, self._ec_p384_private_key.public_key())
                )
                supported_groups.append(Group.SECP384R1)
            elif group == Group.SECP521R1:
                self._ec_p521_private_key = ECDHP521KeyExchange()
                key_share.append(
                    (Group.SECP521R1, self._ec_p521_private_key.public_key())
                )
                supported_groups.append(Group.SECP521R1)
            elif group == Group.X25519:
                self._x25519_private_key = X25519KeyExchange()
                key_share.append((Group.X25519, self._x25519_private_key.public_key()))
                supported_groups.append(Group.X25519)
            elif group == Group.X25519ML768:
                self._x25519_kyber_768_private_key = X25519ML768KeyExchange()
                key_share.append(
                    (
                        Group.X25519ML768,
                        self._x25519_kyber_768_private_key.public_key(),
                    )
                )
                supported_groups.append(Group.X25519ML768)
                if self.__logger is not None:
                    self.__logger.debug(
                        "TLS: Advertising to peer post-quantum algorithm "
                        "using X25519ML768 (0x11EC)"
                    )
            elif group == Group.GREASE:
                # GREASE share: a one-byte placeholder, no private key kept
                key_share.append((Group.GREASE, b"\x00"))
                supported_groups.append(Group.GREASE)

        assert len(key_share), "no key share entries"

        hello = ClientHello(
            random=self.client_random,
            legacy_session_id=self.legacy_session_id,
            cipher_suites=[int(x) for x in self._cipher_suites],
            legacy_compression_methods=self._legacy_compression_methods,
            alpn_protocols=self._alpn_protocols,
            key_share=key_share,
            # advertise PSK modes only if we resume now or want new tickets
            psk_key_exchange_modes=(
                self._psk_key_exchange_modes
                if (self.session_ticket or self.new_session_ticket_cb is not None)
                else None
            ),
            server_name=self._server_name,
            signature_algorithms=self._signature_algorithms,
            supported_groups=supported_groups,
            supported_versions=self._supported_versions,
            other_extensions=self.handshake_extensions,
        )

        # PSK
        if self.session_ticket and self.session_ticket.is_valid:
            self._key_schedule_psk = KeySchedule(self.session_ticket.cipher_suite)
            self._key_schedule_psk.extract(self.session_ticket.resumption_secret)
            binder_key = self._key_schedule_psk.derive_secret(b"res binder")
            binder_length = self._key_schedule_psk.digest_size

            # update hello
            if self.session_ticket.max_early_data_size is not None:
                hello.early_data = True
            # a zero-filled binder reserves the right size until it is known
            hello.pre_shared_key = OfferedPsks(
                identities=[
                    (self.session_ticket.ticket, self.session_ticket.obfuscated_age)
                ],
                binders=[bytes(binder_length)],
            )

            # serialize hello without binder
            tmp_buf = Buffer(capacity=2048)
            push_client_hello(tmp_buf, hello)

            # calculate binder
            # The binder covers the hello up to, but excluding, the binders
            # list. Assumes push_client_hello writes pre_shared_key last
            # (required by RFC 8446 §4.2.11), so the trailing
            # binder_length + 3 bytes are the 2-byte binders-list length,
            # the 1-byte binder length and the placeholder binder.
            hash_offset = tmp_buf.tell() - binder_length - 3
            self._key_schedule_psk.update_hash(tmp_buf.data_slice(0, hash_offset))
            binder = self._key_schedule_psk.finished_verify_data(binder_key)
            hello.pre_shared_key.binders[0] = binder
            self._key_schedule_psk.update_hash(
                tmp_buf.data_slice(hash_offset, hash_offset + 3) + binder
            )

            # calculate early data key
            if hello.early_data:
                early_key = self._key_schedule_psk.derive_secret(b"c e traffic")
                self.update_traffic_key_cb(
                    Direction.ENCRYPT,
                    Epoch.ZERO_RTT,
                    self._key_schedule_psk.cipher_suite,
                    early_key,
                )

        # one schedule per real cipher suite; the server's choice is picked
        # in _client_handle_hello
        self._key_schedule_proxy = KeyScheduleProxy(
            [cs for cs in self._cipher_suites if cs != CipherSuite.GREASE]
        )
        self._key_schedule_proxy.extract(None)

        with push_message(self._key_schedule_proxy, output_buf):
            push_client_hello(output_buf, hello)

        self._set_state(State.CLIENT_EXPECT_SERVER_HELLO)
|
|
|
|
def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None:
|
|
peer_hello = pull_server_hello(input_buf)
|
|
|
|
cipher_suite = negotiate(
|
|
self._cipher_suites,
|
|
[peer_hello.cipher_suite],
|
|
AlertHandshakeFailure("Unsupported cipher suite"),
|
|
excl=CipherSuite.GREASE,
|
|
)
|
|
assert peer_hello.compression_method in self._legacy_compression_methods
|
|
assert peer_hello.supported_version in self._supported_versions
|
|
|
|
# select key schedule
|
|
if peer_hello.pre_shared_key is not None:
|
|
if (
|
|
self._key_schedule_psk is None
|
|
or peer_hello.pre_shared_key != 0
|
|
or cipher_suite != self._key_schedule_psk.cipher_suite
|
|
):
|
|
raise AlertIllegalParameter
|
|
self.key_schedule = self._key_schedule_psk
|
|
self._session_resumed = True
|
|
else:
|
|
self.key_schedule = self._key_schedule_proxy.select(cipher_suite)
|
|
self._key_schedule_psk = None
|
|
self._key_schedule_proxy = None
|
|
|
|
# perform key exchange
|
|
peer_public_key = peer_hello.key_share[1]
|
|
shared_key: bytes | None = None
|
|
|
|
if (
|
|
peer_hello.key_share[0] == Group.X25519
|
|
and self._x25519_private_key is not None
|
|
):
|
|
shared_key = self._x25519_private_key.exchange(peer_public_key)
|
|
elif peer_hello.key_share[0] == Group.X25519ML768:
|
|
shared_key = self._x25519_kyber_768_private_key.exchange(peer_public_key)
|
|
if self.__logger is not None:
|
|
self.__logger.debug(
|
|
"TLS: Post-quantum safety achieved using X25519ML768 (key-exchange)"
|
|
)
|
|
elif (
|
|
peer_hello.key_share[0] == Group.SECP256R1
|
|
and self._ec_p256_private_key is not None
|
|
):
|
|
shared_key = self._ec_p256_private_key.exchange(peer_public_key)
|
|
elif (
|
|
peer_hello.key_share[0] == Group.SECP384R1
|
|
and self._ec_p384_private_key is not None
|
|
):
|
|
shared_key = self._ec_p384_private_key.exchange(peer_public_key)
|
|
elif (
|
|
peer_hello.key_share[0] == Group.SECP521R1
|
|
and self._ec_p521_private_key is not None
|
|
):
|
|
shared_key = self._ec_p521_private_key.exchange(peer_public_key)
|
|
|
|
assert shared_key is not None
|
|
|
|
self.key_schedule.update_hash(input_buf.data)
|
|
self.key_schedule.extract(shared_key)
|
|
|
|
self._setup_traffic_protection(
|
|
Direction.DECRYPT, Epoch.HANDSHAKE, b"s hs traffic"
|
|
)
|
|
|
|
self._set_state(State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS)
|
|
|
|
def _client_handle_encrypted_extensions(self, input_buf: Buffer) -> None:
|
|
encrypted_extensions = pull_encrypted_extensions(input_buf)
|
|
|
|
self.alpn_negotiated = encrypted_extensions.alpn_protocol
|
|
self.early_data_accepted = encrypted_extensions.early_data
|
|
self.received_extensions = encrypted_extensions.other_extensions
|
|
if self.alpn_cb:
|
|
self.alpn_cb(self.alpn_negotiated)
|
|
|
|
self._setup_traffic_protection(
|
|
Direction.ENCRYPT, Epoch.HANDSHAKE, b"c hs traffic"
|
|
)
|
|
self.key_schedule.update_hash(input_buf.data)
|
|
|
|
# if the server accepted our PSK we are done, other we want its certificate
|
|
if self._session_resumed:
|
|
self._set_state(State.CLIENT_EXPECT_FINISHED)
|
|
else:
|
|
self._set_state(State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE)
|
|
|
|
def _client_handle_certificate(self, input_buf: Buffer) -> None:
|
|
certificate = pull_certificate(input_buf)
|
|
|
|
# attempt to extract a possible OCSP staple extension from
|
|
# the leaf certificate only.
|
|
ext_buf = Buffer(data=certificate.certificates[0][1])
|
|
|
|
try:
|
|
# RFC 8446, Section 4.4.2.2
|
|
|
|
while not ext_buf.eof():
|
|
ext_type = ext_buf.pull_uint16()
|
|
ext_len = ext_buf.pull_uint16()
|
|
|
|
if ext_type == ExtensionType.STATUS_REQUEST:
|
|
status_type = ext_buf.pull_uint8()
|
|
if status_type == 1:
|
|
resp_len = ext_buf.pull_uint24()
|
|
self._ocsp_response = ext_buf.pull_bytes(resp_len)
|
|
break
|
|
else:
|
|
ext_buf.pull_bytes(ext_len - 1)
|
|
else:
|
|
break
|
|
except BufferReadError:
|
|
pass # Defensive: against malformed extensions.
|
|
|
|
self._peer_certificate = X509Certificate(certificate.certificates[0][0])
|
|
self._peer_certificate_chain = [
|
|
X509Certificate(certificate.certificates[i][0])
|
|
for i in range(1, len(certificate.certificates))
|
|
]
|
|
|
|
self.key_schedule.update_hash(input_buf.data)
|
|
|
|
self._set_state(State.CLIENT_EXPECT_CERTIFICATE_VERIFY)
|
|
|
|
def _client_handle_certificate_request(self, input_buf: Buffer) -> None:
|
|
self._certificate_request = pull_certificate_request(input_buf)
|
|
self.key_schedule.update_hash(input_buf.data)
|
|
self._set_state(State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE)
|
|
|
|
def _client_handle_certificate_verify(self, input_buf: Buffer) -> None:
|
|
verify = pull_certificate_verify(input_buf)
|
|
|
|
assert verify.algorithm in self._signature_algorithms
|
|
|
|
# check signature
|
|
try:
|
|
verify_with_public_key(
|
|
self._peer_certificate.public_key(),
|
|
verify.algorithm,
|
|
self.key_schedule.certificate_verify_data(
|
|
b"TLS 1.3, server CertificateVerify"
|
|
),
|
|
verify.signature,
|
|
)
|
|
except SignatureError as e:
|
|
raise AlertDecryptError(str(e))
|
|
|
|
# check certificate
|
|
if self._verify_mode != ssl.CERT_NONE:
|
|
verify_certificate(
|
|
cadata=self._cadata,
|
|
cafile=self._cafile,
|
|
capath=self._capath,
|
|
certificate=self._peer_certificate,
|
|
chain=self._peer_certificate_chain,
|
|
server_name=self._server_name,
|
|
assert_server_name=self._verify_hostname,
|
|
ocsp_response=self._ocsp_response,
|
|
)
|
|
|
|
if self._assert_fingerprint is not None:
|
|
fingerprint = self._assert_fingerprint.replace(":", "").lower()
|
|
digest_length = len(fingerprint)
|
|
hashfunc = HASHFUNC_MAP.get(digest_length)
|
|
|
|
if not hashfunc:
|
|
raise AlertBadCertificate(
|
|
f"Fingerprint of invalid length: {fingerprint}"
|
|
)
|
|
|
|
expect_fingerprint = unhexlify(fingerprint.encode())
|
|
peer_fingerprint = hashfunc(self._peer_certificate.public_bytes()).digest()
|
|
|
|
if peer_fingerprint != expect_fingerprint:
|
|
raise AlertBadCertificate(
|
|
"Fingerprints did not match. "
|
|
f'Expected "{expect_fingerprint.hex()}", '
|
|
f'got "{peer_fingerprint.hex()}"'
|
|
)
|
|
|
|
self.key_schedule.update_hash(input_buf.data)
|
|
|
|
self._set_state(State.CLIENT_EXPECT_FINISHED)
|
|
|
|
    def _client_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None:
        """
        Verify the server Finished and send the client's final flight.

        Steps:
        - Check the server's verify_data and derive the application traffic
          secrets.
        - If the server sent a CertificateRequest, write a Certificate
          (possibly empty) and, when both a certificate and a private key
          are configured, a CertificateVerify.
        - Write the client Finished and install the client application
          traffic key.

        :param input_buf: buffer holding exactly one Finished message.
        :param output_buf: HANDSHAKE-epoch buffer for the client's messages.
        :raises AlertDecryptError: the server's verify_data does not match.
        :raises AlertHandshakeFailure: no signature algorithm suits both our
            key and the server's CertificateRequest.
        """
        finished = pull_finished(input_buf)

        # check verify data
        # NOTE(review): _dec_key presumably holds the server handshake traffic
        # secret installed by _setup_traffic_protection — confirm.
        expected_verify_data = self.key_schedule.finished_verify_data(self._dec_key)
        if finished.verify_data != expected_verify_data:
            raise AlertDecryptError
        self.key_schedule.update_hash(input_buf.data)

        # prepare traffic keys
        # the early and handshake extractions are done; this third one
        # yields the master secret
        assert self.key_schedule.generation == 2
        self.key_schedule.extract(None)
        self._setup_traffic_protection(
            Direction.DECRYPT, Epoch.ONE_RTT, b"s ap traffic"
        )
        # derived now because the client application secret covers the
        # transcript only up to the server Finished
        next_enc_key = self.key_schedule.derive_secret(b"c ap traffic")

        if self._certificate_request is not None:
            # an empty list is sent when no certificate is configured
            with push_message(self.key_schedule, output_buf):
                push_certificate(
                    output_buf,
                    Certificate(
                        request_context=self._certificate_request.request_context,
                        certificates=[
                            (cert.public_bytes(), b"")
                            for cert in [self.certificate] + self.certificate_chain
                            if cert is not None
                        ],
                    ),
                )

            if None not in (self.certificate, self.certificate_private_key):
                # determine applicable signature algorithms
                signature_algorithms: list[SignatureAlgorithm] = []

                if isinstance(self.certificate_private_key, RsaPrivateKey):
                    signature_algorithms = [
                        SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
                        SignatureAlgorithm.RSA_PKCS1_SHA256,
                    ]
                elif isinstance(self.certificate_private_key, EcPrivateKey):
                    if self.certificate_private_key.curve_type == 256:
                        signature_algorithms = [
                            SignatureAlgorithm.ECDSA_SECP256R1_SHA256
                        ]
                    elif self.certificate_private_key.curve_type == 384:
                        signature_algorithms = [
                            SignatureAlgorithm.ECDSA_SECP384R1_SHA384
                        ]
                    elif self.certificate_private_key.curve_type == 521:
                        signature_algorithms = [
                            SignatureAlgorithm.ECDSA_SECP521R1_SHA512
                        ]
                elif isinstance(self.certificate_private_key, Ed25519PrivateKey):
                    signature_algorithms = [SignatureAlgorithm.ED25519]

                # an empty list (e.g. a DSA key) ends in AlertHandshakeFailure
                signature_algorithm = negotiate(
                    signature_algorithms,
                    self._certificate_request.signature_algorithms,
                    AlertHandshakeFailure("No supported signature algorithm"),
                )

                signature = self.certificate_private_key.sign(
                    self.key_schedule.certificate_verify_data(
                        b"TLS 1.3, client CertificateVerify"
                    ),
                    *signature_algorithm_params(signature_algorithm),
                )

                with push_message(self.key_schedule, output_buf):
                    push_certificate_verify(
                        output_buf,
                        CertificateVerify(
                            algorithm=signature_algorithm, signature=signature
                        ),
                    )

        # send finished
        # keyed with the client handshake secret (_enc_key is still that value
        # here; it is replaced just below)
        with push_message(self.key_schedule, output_buf):
            push_finished(
                output_buf,
                Finished(
                    verify_data=self.key_schedule.finished_verify_data(self._enc_key)
                ),
            )

        # commit traffic key
        self._enc_key = next_enc_key
        self.update_traffic_key_cb(
            Direction.ENCRYPT,
            Epoch.ONE_RTT,
            self.key_schedule.cipher_suite,
            self._enc_key,
        )

        self._set_state(State.CLIENT_POST_HANDSHAKE)
|
|
|
|
def _client_handle_new_session_ticket(self, input_buf: Buffer) -> None:
|
|
new_session_ticket = pull_new_session_ticket(input_buf)
|
|
|
|
# notify application
|
|
if self.new_session_ticket_cb is not None:
|
|
ticket = self._build_session_ticket(
|
|
new_session_ticket, self.received_extensions
|
|
)
|
|
self.new_session_ticket_cb(ticket)
|
|
|
|
def _server_handle_hello(
    self,
    input_buf: Buffer,
    initial_buf: Buffer,
    handshake_buf: Buffer,
    onertt_buf: Buffer,
) -> None:
    """Process a ClientHello and emit the server's first flight.

    Negotiates cipher suite, compression method, PSK key exchange mode,
    signature algorithm, protocol version and (if configured) ALPN. It then
    either resumes a session from a pre-shared key or prepares a full
    handshake, and performs the key exchange.

    Output is written as follows:

    * ServerHello -> *initial_buf*
    * EncryptedExtensions, then (full handshake only) Certificate and
      CertificateVerify, then Finished -> *handshake_buf*
    * NewSessionTicket, when a ticket is issued -> *onertt_buf*

    The client's Finished is precomputed, and the state moves to
    ``State.SERVER_EXPECT_FINISHED``.

    :param input_buf: buffer holding the complete ClientHello message.
    :param initial_buf: output buffer for the ServerHello.
    :param handshake_buf: output buffer for handshake-epoch messages.
    :param onertt_buf: output buffer for the NewSessionTicket.
    :raises AlertHandshakeFailure: if no cipher suite, compression method,
        signature algorithm, ALPN protocol or key share group is shared with
        the client, or if the PSK binder does not verify.
    :raises AlertProtocolVersion: if no protocol version is shared.
    """
    from hmac import compare_digest

    peer_hello = pull_client_hello(input_buf)

    # determine applicable signature algorithms
    signature_algorithms: list[SignatureAlgorithm] = []

    if isinstance(self.certificate_private_key, RsaPrivateKey):
        # NOTE(review): RFC 8446 §4.4.3 mandates RSASSA-PSS for
        # CertificateVerify; PKCS#1 v1.5 is kept as before, presumably for
        # interoperability with peers that do not offer PSS — confirm.
        signature_algorithms = [
            SignatureAlgorithm.RSA_PSS_RSAE_SHA256,
            SignatureAlgorithm.RSA_PKCS1_SHA256,
        ]
    elif isinstance(self.certificate_private_key, EcPrivateKey):
        if self.certificate_private_key.curve_type == 256:
            signature_algorithms = [SignatureAlgorithm.ECDSA_SECP256R1_SHA256]
        elif self.certificate_private_key.curve_type == 384:
            signature_algorithms = [SignatureAlgorithm.ECDSA_SECP384R1_SHA384]
        elif self.certificate_private_key.curve_type == 521:
            signature_algorithms = [SignatureAlgorithm.ECDSA_SECP521R1_SHA512]
    elif isinstance(self.certificate_private_key, Ed25519PrivateKey):
        signature_algorithms = [SignatureAlgorithm.ED25519]

    # negotiate parameters
    cipher_suite = negotiate(
        self._cipher_suites,
        peer_hello.cipher_suites,
        AlertHandshakeFailure("No supported cipher suite"),
        excl=CipherSuite.GREASE,
    )
    compression_method = negotiate(
        self._legacy_compression_methods,
        peer_hello.legacy_compression_methods,
        AlertHandshakeFailure("No supported compression method"),
    )
    # no alert passed: a missing PSK mode simply disables resumption below
    psk_key_exchange_mode = negotiate(
        self._psk_key_exchange_modes, peer_hello.psk_key_exchange_modes
    )
    signature_algorithm = negotiate(
        signature_algorithms,
        peer_hello.signature_algorithms,
        AlertHandshakeFailure("No supported signature algorithm"),
    )
    supported_version = negotiate(
        self._supported_versions,
        peer_hello.supported_versions,
        AlertProtocolVersion("No supported protocol version"),
    )

    # negotiate ALPN
    if self._alpn_protocols is not None:
        self.alpn_negotiated = negotiate(
            self._alpn_protocols,
            peer_hello.alpn_protocols,
            AlertHandshakeFailure("No common ALPN protocols"),
        )

    self.client_random = peer_hello.random
    self.server_random = os.urandom(32)
    self.legacy_session_id = peer_hello.legacy_session_id
    self.received_extensions = peer_hello.other_extensions

    # NOTE(review): invoked even when no ALPN negotiation happened above;
    # presumably alpn_negotiated defaults to None elsewhere — confirm.
    if self.alpn_cb:
        self.alpn_cb(self.alpn_negotiated)

    # select key schedule
    pre_shared_key = None
    if (
        self.get_session_ticket_cb is not None
        and psk_key_exchange_mode is not None
        and peer_hello.pre_shared_key is not None
        and len(peer_hello.pre_shared_key.identities) == 1
        and len(peer_hello.pre_shared_key.binders) == 1
    ):
        # ask application to find session ticket
        identity = peer_hello.pre_shared_key.identities[0]
        session_ticket = self.get_session_ticket_cb(identity[0])

        # validate session ticket
        if (
            session_ticket is not None
            and session_ticket.is_valid
            and session_ticket.cipher_suite == cipher_suite
        ):
            self.key_schedule = KeySchedule(cipher_suite)
            self.key_schedule.extract(session_ticket.resumption_secret)

            binder_key = self.key_schedule.derive_secret(b"res binder")
            binder_length = self.key_schedule.digest_size

            # The single binder is taken from the very end of the message
            # (RFC 8446 requires pre_shared_key to be the last extension).
            # The 3 bytes are the 2-byte binders-list length plus the 1-byte
            # binder length. The transcript for the binder covers everything
            # before them.
            hash_offset = input_buf.tell() - binder_length - 3
            binder = input_buf.data_slice(
                hash_offset + 3, hash_offset + 3 + binder_length
            )

            self.key_schedule.update_hash(input_buf.data_slice(0, hash_offset))
            expected_binder = self.key_schedule.finished_verify_data(binder_key)

            # constant-time comparison: the binder is a MAC, so timing must
            # not reveal how many leading bytes matched
            if not compare_digest(binder, expected_binder):
                raise AlertHandshakeFailure("PSK validation failed")

            # complete the transcript with the binders themselves
            self.key_schedule.update_hash(
                input_buf.data_slice(hash_offset, hash_offset + 3 + binder_length)
            )
            self._session_resumed = True

            # calculate early data key
            if peer_hello.early_data:
                early_key = self.key_schedule.derive_secret(b"c e traffic")
                self.early_data_accepted = True
                self.update_traffic_key_cb(
                    Direction.DECRYPT,
                    Epoch.ZERO_RTT,
                    self.key_schedule.cipher_suite,
                    early_key,
                )

            # index of the accepted PSK identity (the only one offered)
            pre_shared_key = 0

    # if PSK is not used, initialize key schedule
    if pre_shared_key is None:
        self.key_schedule = KeySchedule(cipher_suite)
        self.key_schedule.extract(None)
        self.key_schedule.update_hash(input_buf.data)

    # perform key exchange using the first key share whose group we support
    public_key: bytes | None = None
    group_kx: Group | None = None
    shared_key: bytes | None = None

    for key_share in peer_hello.key_share:
        peer_public_key = key_share[1]

        if key_share[0] == Group.X25519:
            self._x25519_private_key = X25519KeyExchange()
            public_key = self._x25519_private_key.public_key()
            shared_key = self._x25519_private_key.exchange(peer_public_key)
            group_kx = Group.X25519
            break
        elif key_share[0] == Group.X25519ML768:
            # hybrid KEM: the server's key share is the encapsulation
            # ciphertext, available only after exchange()
            self._x25519_kyber_768_private_key = X25519ML768KeyExchange()
            shared_key = self._x25519_kyber_768_private_key.exchange(
                peer_public_key
            )
            public_key = self._x25519_kyber_768_private_key.shared_ciphertext()
            group_kx = Group.X25519ML768
            break
        elif key_share[0] == Group.SECP256R1:
            self._ec_p256_private_key = ECDHP256KeyExchange()
            public_key = self._ec_p256_private_key.public_key()
            shared_key = self._ec_p256_private_key.exchange(peer_public_key)
            group_kx = Group.SECP256R1
            break
        elif key_share[0] == Group.SECP384R1:
            self._ec_p384_private_key = ECDHP384KeyExchange()
            public_key = self._ec_p384_private_key.public_key()
            shared_key = self._ec_p384_private_key.exchange(peer_public_key)
            group_kx = Group.SECP384R1
            break
        elif key_share[0] == Group.SECP521R1:
            self._ec_p521_private_key = ECDHP521KeyExchange()
            public_key = self._ec_p521_private_key.public_key()
            shared_key = self._ec_p521_private_key.exchange(peer_public_key)
            group_kx = Group.SECP521R1
            break

    # This handler never sends a HelloRetryRequest. Without a usable key
    # share the handshake cannot proceed. Raise a proper alert rather than
    # relying on an assert, which is stripped under `python -O`.
    if shared_key is None:
        raise AlertHandshakeFailure("No supported key exchange group")

    # send hello
    hello = ServerHello(
        random=self.server_random,
        legacy_session_id=self.legacy_session_id,
        cipher_suite=cipher_suite,
        compression_method=compression_method,
        key_share=(group_kx, public_key),
        pre_shared_key=pre_shared_key,
        supported_version=supported_version,
    )
    with push_message(self.key_schedule, initial_buf):
        push_server_hello(initial_buf, hello)
    self.key_schedule.extract(shared_key)

    self._setup_traffic_protection(
        Direction.ENCRYPT, Epoch.HANDSHAKE, b"s hs traffic"
    )
    self._setup_traffic_protection(
        Direction.DECRYPT, Epoch.HANDSHAKE, b"c hs traffic"
    )

    # send encrypted extensions
    with push_message(self.key_schedule, handshake_buf):
        push_encrypted_extensions(
            handshake_buf,
            EncryptedExtensions(
                alpn_protocol=self.alpn_negotiated,
                early_data=self.early_data_accepted,
                other_extensions=self.handshake_extensions,
            ),
        )

    # a resumed session authenticates via the PSK, so no certificate is sent
    if pre_shared_key is None:
        # send certificate
        with push_message(self.key_schedule, handshake_buf):
            push_certificate(
                handshake_buf,
                Certificate(
                    request_context=b"",
                    certificates=[
                        (x.public_bytes(), b"")
                        for x in [self.certificate] + self.certificate_chain
                    ],
                ),
            )

        # send certificate verify
        signature = self.certificate_private_key.sign(
            self.key_schedule.certificate_verify_data(
                b"TLS 1.3, server CertificateVerify"
            ),
            *signature_algorithm_params(signature_algorithm),
        )
        with push_message(self.key_schedule, handshake_buf):
            push_certificate_verify(
                handshake_buf,
                CertificateVerify(
                    algorithm=signature_algorithm, signature=signature
                ),
            )

    # send finished
    with push_message(self.key_schedule, handshake_buf):
        push_finished(
            handshake_buf,
            Finished(
                verify_data=self.key_schedule.finished_verify_data(self._enc_key)
            ),
        )

    # prepare traffic keys
    assert self.key_schedule.generation == 2
    self.key_schedule.extract(None)
    self._setup_traffic_protection(
        Direction.ENCRYPT, Epoch.ONE_RTT, b"s ap traffic"
    )
    # the client's 1-RTT key is only committed once its Finished verifies
    self._next_dec_key = self.key_schedule.derive_secret(b"c ap traffic")

    # anticipate client's FINISHED as we don't use client auth
    self._expected_verify_data = self.key_schedule.finished_verify_data(
        self._dec_key
    )
    buf = Buffer(capacity=64)
    push_finished(buf, Finished(verify_data=self._expected_verify_data))
    self.key_schedule.update_hash(buf.data)

    # create a new session ticket
    if self.new_session_ticket_cb is not None and psk_key_exchange_mode is not None:
        self._new_session_ticket = NewSessionTicket(
            ticket_lifetime=86400,
            ticket_age_add=struct.unpack("I", os.urandom(4))[0],
            ticket_nonce=b"",
            ticket=os.urandom(64),
            max_early_data_size=self._max_early_data,
        )

        # send message
        push_new_session_ticket(onertt_buf, self._new_session_ticket)

        # notify application
        ticket = self._build_session_ticket(
            self._new_session_ticket, self.handshake_extensions
        )
        self.new_session_ticket_cb(ticket)

    self._set_state(State.SERVER_EXPECT_FINISHED)
|
|
|
|
def _server_handle_finished(self, input_buf: Buffer, output_buf: Buffer) -> None:
    """Verify the client's Finished message and enable 1-RTT decryption.

    The received verify_data is checked against ``self._expected_verify_data``,
    which was computed while sending the server flight. On success, the
    pending client application key (``self._next_dec_key``) becomes the
    decryption key and is passed to ``update_traffic_key_cb``. The state
    then moves to ``State.SERVER_POST_HANDSHAKE``.

    :param input_buf: buffer holding the client's Finished message.
    :param output_buf: not used by this handler.
    :raises AlertDecryptError: if the verify_data does not match.
    """
    from hmac import compare_digest

    finished = pull_finished(input_buf)

    # check verify data; the value is a MAC, so compare in constant time
    # to avoid leaking how many leading bytes were correct
    if not compare_digest(finished.verify_data, self._expected_verify_data):
        raise AlertDecryptError

    # commit traffic key
    self._dec_key = self._next_dec_key
    self._next_dec_key = None
    self.update_traffic_key_cb(
        Direction.DECRYPT,
        Epoch.ONE_RTT,
        self.key_schedule.cipher_suite,
        self._dec_key,
    )

    self._set_state(State.SERVER_POST_HANDSHAKE)
|
|
|
|
def _setup_traffic_protection(
    self, direction: Direction, epoch: Epoch, label: bytes
) -> None:
    """Install a freshly derived traffic secret for one direction.

    The secret is derived from the current key schedule using *label*. It is
    stored as ``_enc_key`` for ``Direction.ENCRYPT`` and as ``_dec_key`` for
    any other direction. It is then passed to ``update_traffic_key_cb``
    together with *direction*, *epoch* and the key schedule's cipher suite.
    """
    secret = self.key_schedule.derive_secret(label)

    if direction != Direction.ENCRYPT:
        self._dec_key = secret
    else:
        self._enc_key = secret

    suite = self.key_schedule.cipher_suite
    self.update_traffic_key_cb(direction, epoch, suite, secret)
|
|
|
|
def _set_state(self, state: State) -> None:
    """Switch the state machine to *state*, logging the transition if a logger is set."""
    logger = self.__logger
    if logger:
        previous_name = self.state.name
        logger.debug("TLS %s -> %s", previous_name, state.name)
    self.state = state
|