fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,12 @@
from __future__ import annotations
from .factories import ResolverDescription, ResolverFactory
from .protocols import BaseResolver, ManyResolver, ProtocolResolver
__all__ = (
"ResolverFactory",
"ProtocolResolver",
"BaseResolver",
"ManyResolver",
"ResolverDescription",
)

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from .factories import AsyncResolverDescription, AsyncResolverFactory
from .protocols import AsyncBaseResolver, AsyncManyResolver
__all__ = (
"AsyncResolverDescription",
"AsyncResolverFactory",
"AsyncBaseResolver",
"AsyncManyResolver",
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,656 @@
from __future__ import annotations
import socket
import typing
from asyncio import as_completed
from base64 import b64encode
from ....._async.connectionpool import AsyncHTTPSConnectionPool
from ....._async.response import AsyncHTTPResponse
from ....._collections import HTTPHeaderDict
from .....backend import ConnectionInfo, HttpVersion
from .....util.url import parse_url
from ...protocols import (
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
from ..protocols import AsyncBaseResolver
class HTTPSResolver(AsyncBaseResolver):
"""
Advanced DNS over HTTPS resolver.
No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
also implemented at Cloudflare.
Support RFC 8484 without JSON. Disabled by default.
"""
implementation = "urllib3"
protocol = ProtocolResolver.DOH
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port or 443, *patterns, **kwargs)
self._path: str = "/resolve"
if "path" in kwargs:
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
self._path = kwargs["path"]
kwargs.pop("path")
self._rfc8484: bool = False
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
self._rfc8484 = True
kwargs.pop("rfc8484")
assert self._server is not None
if "source_address" in kwargs:
if isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
if bind_ip and bind_port.isdigit():
kwargs["source_address"] = (
bind_ip,
int(bind_port),
)
else:
raise ValueError("invalid source_address given in parameters")
else:
raise ValueError("invalid source_address given in parameters")
if "proxy" in kwargs:
kwargs["_proxy"] = parse_url(kwargs["proxy"])
kwargs.pop("proxy")
if "maxsize" not in kwargs:
kwargs["maxsize"] = 10
if "proxy_headers" in kwargs and "_proxy" in kwargs:
proxy_headers = HTTPHeaderDict()
if not isinstance(kwargs["proxy_headers"], list):
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
for item in kwargs["proxy_headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
proxy_headers.add(k, v)
kwargs["_proxy_headers"] = proxy_headers
if "headers" in kwargs:
headers = HTTPHeaderDict()
if not isinstance(kwargs["headers"], list):
kwargs["headers"] = [kwargs["headers"]]
for item in kwargs["headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
headers.add(k, v)
kwargs["headers"] = headers
if "disabled_svn" in kwargs:
if not isinstance(kwargs["disabled_svn"], list):
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
disabled_svn = set()
for svn in kwargs["disabled_svn"]:
svn = svn.lower()
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
kwargs["disabled_svn"] = disabled_svn
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
self._connection_callback: (
typing.Callable[[ConnectionInfo], None] | None
) = kwargs["on_post_connection"]
kwargs.pop("on_post_connection")
else:
self._connection_callback = None
self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
async def close(self) -> None: # type: ignore[override]
await self._pool.close()
def is_available(self) -> bool:
return self._pool.pool is not None
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a HTTPSResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
promises = []
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
if family in [socket.AF_UNSPEC, socket.AF_INET]:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "1"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.A, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "28"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.AAAA, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if quic_upgrade_via_dns_rr:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "65"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.HTTPS, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
tasks = []
responses = []
for promise in promises:
# already resolved
if isinstance(promise, AsyncHTTPResponse):
responses.append(promise)
continue
tasks.append(self._pool.get_response(promise=promise))
if tasks:
for waiting_promise_coro in as_completed(tasks):
responses.append(await waiting_promise_coro) # type: ignore[arg-type]
results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
for response in responses:
if response.status >= 300:
raise socket.gaierror(
f"DNS over HTTPS was unsuccessful, server response status {response.status}."
)
if not self._rfc8484:
payload = await response.json()
assert "Status" in payload and isinstance(payload["Status"], int)
if payload["Status"] != 0:
msg = (
payload["Comment"]
if "Comment" in payload
else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
)
if isinstance(msg, list):
msg = ", ".join(msg)
raise socket.gaierror(msg)
assert "Question" in payload and isinstance(payload["Question"], list)
if "Answer" not in payload:
continue
assert isinstance(payload["Answer"], list)
for answer in payload["Answer"]:
if answer["type"] not in [1, 28, 65]:
continue
assert "data" in answer
assert isinstance(answer["data"], str)
# DNS RR/HTTPS
if answer["type"] == 65:
# "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
# or..
# "1 . alpn=h2,h3"
rr: str = answer["data"]
if rr.startswith("\\#"): # it means, raw, bytes.
rr = "".join(rr[2:].split(" ")[2:])
try:
raw_record = bytes.fromhex(rr)
except ValueError:
raw_record = b""
if not raw_record:
continue
https_record = parse_https_rdata(raw_record)
if "h3" not in https_record["alpn"]:
continue
remote_preemptive_quic_rr = True
else:
rr_decode: dict[str, str] = dict(
tuple(_.lower().split("=", 1)) # type: ignore[misc]
for _ in rr.split(" ")
if "=" in _
)
if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
continue
remote_preemptive_quic_rr = True
if "ipv4hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET,
]:
for ipv4 in rr_decode["ipv4hint"].split(","):
results.append(
(
socket.AF_INET,
socket.SOCK_DGRAM,
17,
"",
(
ipv4,
port,
),
)
)
if "ipv6hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET6,
]:
for ipv6 in rr_decode["ipv6hint"].split(","):
results.append(
(
socket.AF_INET6,
socket.SOCK_DGRAM,
17,
"",
(
ipv6,
port,
0,
0,
),
)
)
continue
inet_type = (
socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
answer["data"],
port,
)
if inet_type == socket.AF_INET
else (
answer["data"],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
else:
dns_resp = DomainNameServerReturn(await response.data)
for record in dns_resp.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
kwargs["path"] = "/dns-query"
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query"})
super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
try:
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
QUICResolver = None # type: ignore
AdGuardResolver = None # type: ignore
NextDNSResolver = None # type: ignore
__all__ = (
"QUICResolver",
"AdGuardResolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,557 @@
from __future__ import annotations
import asyncio
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..dou import PlainResolver
from ..system import SystemResolver
if IS_FIPS:
raise ImportError(
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
)
class QUICResolver(PlainResolver):
protocol = ProtocolResolver.DOQ
implementation = "qh3"
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
):
super().__init__(server, port or 853, *patterns, **kwargs)
# qh3 load_default_certs seems off. need to investigate.
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
kwargs["ca_cert_data"] = []
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
try:
ctx.load_default_certs()
for der in ctx.get_ca_certs(binary_form=True):
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
if kwargs["ca_cert_data"]:
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
else:
del kwargs["ca_cert_data"]
except (AttributeError, ValueError, OSError):
del kwargs["ca_cert_data"]
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
if (
"cert_reqs" not in kwargs
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
):
raise ssl.SSLError(
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
"Add ?cert_reqs=0 to disable certificate checks."
)
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=["doq"],
server_name=self._server
if "server_hostname" not in kwargs
else kwargs["server_hostname"],
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
if "cert_reqs" in kwargs
else ssl.CERT_REQUIRED,
cadata=kwargs["ca_cert_data"].encode()
if "ca_cert_data" in kwargs
else None,
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
idle_timeout=300.0,
)
if "cert_file" in kwargs:
configuration.load_cert_chain(
kwargs["cert_file"],
kwargs["key_file"] if "key_file" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
elif "cert_data" in kwargs:
configuration.load_cert_chain(
kwargs["cert_data"],
kwargs["key_data"] if "key_data" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
self._quic = QuicConnection(configuration=configuration)
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
self._connect_attempt: asyncio.Event = asyncio.Event()
self._handshake_event: asyncio.Event = asyncio.Event()
self._terminated: bool = False
self._should_disconnect: bool = False
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
self._rfc1035_prefix_mandated = True
self._unconsumed: deque[DomainNameServerReturn] = deque()
self._pending: deque[DomainNameServerQuery] = deque()
async def close(self) -> None: # type: ignore[override]
if (
not self._terminated
and self._socket is not None
and not self._socket.should_connect()
):
self._quic.close()
while True:
datagrams = self._quic.datagrams_to_send(monotonic())
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
self._socket.close()
await self._socket.wait_for_close()
self._terminated = True
if self._socket is None or self._socket.should_connect():
self._terminated = True
def is_available(self) -> bool:
if self._terminated:
return False
if self._socket is None or self._socket.should_connect():
return True
self._quic.handle_timer(monotonic())
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
self._terminated = True
return not self._terminated
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' using the QUICResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
if self._socket is None and self._connect_attempt.is_set() is False:
self._connect_attempt.set()
assert self.server is not None
self._quic.connect((self._server, self._port), monotonic())
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 853),
timeout=self._timeout,
source_address=self._source_address,
socket_options=None,
socket_kind=self._socket_type,
)
await self.__exchange_until(HandshakeCompleted, receive_first=False)
self._handshake_event.set()
else:
await self._handshake_event.wait()
assert self._socket is not None
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
open_streams = []
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
stream_id = self._quic.get_next_available_stream_id()
self._quic.send_stream_data(stream_id, payload, True)
open_streams.append(stream_id)
for dg in self._quic.datagrams_to_send(monotonic()):
await self._socket.sendall(dg[0])
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
await self._read_semaphore.acquire()
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._unconsumed.remove(dns_resp)
self._pending.remove(query)
self._read_semaphore.release()
continue
try:
events: list[StreamDataReceived] = await self.__exchange_until( # type: ignore[assignment]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
payload = b"".join([e.data for e in events])
while rfc1035_should_read(payload):
events.extend(
await self.__exchange_until( # type: ignore[arg-type]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
)
payload = b"".join([e.data for e in events])
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
self._read_semaphore.release()
if not payload:
continue
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
fragments = rfc1035_unpack(payload)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
if self._should_disconnect:
await self.close()
self._should_disconnect = False
self._terminated = True
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
async def __exchange_until(
self,
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
*,
receive_first: bool = False,
event_type_collectable: type[QuicEvent]
| tuple[type[QuicEvent], ...]
| None = None,
respect_end_stream_signal: bool = True,
) -> list[QuicEvent]:
assert self._socket is not None
while True:
if receive_first is False:
now = monotonic()
while True:
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
events = []
while True:
if not self._quic._events:
data_in = await self._socket.recv(1500)
if not data_in:
break
now = monotonic()
if not isinstance(data_in, list):
self._quic.receive_datagram(
data_in, (self._server, self._port), now
)
else:
for gro_segment in data_in:
self._quic.receive_datagram(
gro_segment, (self._server, self._port), now
)
while True:
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
for ev in iter(self._quic.next_event, None):
if isinstance(ev, ConnectionTerminated):
if ev.error_code == 298:
raise SSLError(
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
)
raise socket.gaierror(
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
)
elif isinstance(ev, StreamReset):
self._terminated = True
raise socket.gaierror(
"DNS over QUIC server submitted a StreamReset. A request was rejected."
)
elif isinstance(ev, StopSendingReceived):
self._should_disconnect = True
continue
if event_type_collectable:
if isinstance(ev, event_type_collectable):
events.append(ev)
else:
events.append(ev)
if isinstance(ev, event_type):
if not respect_end_stream_signal:
return events
if hasattr(ev, "stream_ended") and ev.stream_ended:
return events
elif hasattr(ev, "stream_ended") is False:
return events
return events
class AdGuardResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import socket
import typing
from .....util._async.ssl_ import ssl_wrap_socket
from .....util.ssl_ import resolve_cert_reqs
from ...protocols import ProtocolResolver
from ..dou import PlainResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
"""
Basic DNS resolver over TLS.
Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
"""
protocol = ProtocolResolver.DOT
implementation = "ssl"
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
self._socket_type = socket.SOCK_STREAM
super().__init__(server, port or 853, *patterns, **kwargs)
# DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
self._rfc1035_prefix_mandated = True
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if self._socket is None and self._connect_attempt.is_set() is False:
assert self.server is not None
self._connect_attempt.set()
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 853),
timeout=self._timeout,
source_address=self._source_address,
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
socket_kind=self._socket_type,
)
self._socket = await ssl_wrap_socket(
self._socket,
server_hostname=self.server
if "server_hostname" not in self._kwargs
else self._kwargs["server_hostname"],
keyfile=self._kwargs["key_file"]
if "key_file" in self._kwargs
else None,
certfile=self._kwargs["cert_file"]
if "cert_file" in self._kwargs
else None,
cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
if "cert_reqs" in self._kwargs
else None,
ca_certs=self._kwargs["ca_certs"]
if "ca_certs" in self._kwargs
else None,
ssl_version=self._kwargs["ssl_version"]
if "ssl_version" in self._kwargs
else None,
ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
ca_cert_dir=self._kwargs["ca_cert_dir"]
if "ca_cert_dir" in self._kwargs
else None,
key_password=self._kwargs["key_password"]
if "key_password" in self._kwargs
else None,
ca_cert_data=self._kwargs["ca_cert_data"]
if "ca_cert_data" in self._kwargs
else None,
certdata=self._kwargs["cert_data"]
if "cert_data" in self._kwargs
else None,
keydata=self._kwargs["key_data"]
if "key_data" in self._kwargs
else None,
)
self._connect_finalized.set()
return await super().getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
class GoogleResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
import asyncio
import socket
import typing
from collections import deque
from ....ssa import AsyncSocket
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..protocols import AsyncBaseResolver
from ..system import SystemResolver
class PlainResolver(AsyncBaseResolver):
"""
Minimalist DNS resolver over UDP
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
EDNS is not supported, yet. But we plan to. Willing to contribute?
"""
protocol = ProtocolResolver.DOU
implementation = "socket"
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port, *patterns, **kwargs)
self._socket: AsyncSocket | None = None
if not hasattr(self, "_socket_type"):
self._socket_type = socket.SOCK_DGRAM
if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
if ":" in kwargs["source_address"]:
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
else:
self._source_address = (kwargs["source_address"], 0)
else:
self._source_address = None
if "timeout" in kwargs and isinstance(
kwargs["timeout"],
(
float,
int,
),
):
self._timeout: float | int | None = kwargs["timeout"]
else:
self._timeout = None
#: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
self._rfc1035_prefix_mandated: bool = False
self._unconsumed: deque[DomainNameServerReturn] = deque()
self._pending: deque[DomainNameServerQuery] = deque()
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
self._connect_attempt: asyncio.Event = asyncio.Event()
self._connect_finalized: asyncio.Event = asyncio.Event()
self._terminated: bool = False
async def close(self) -> None: # type: ignore[override]
if not self._terminated:
with self._lock:
if self._socket is not None:
self._socket.close()
await self._socket.wait_for_close()
self._terminated = True
def is_available(self) -> bool:
return not self._terminated
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a PlainResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
if self._socket is None and self._connect_attempt.is_set() is False:
self._connect_attempt.set()
assert self.server is not None
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 53),
timeout=self._timeout,
source_address=self._source_address,
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
socket_kind=self._socket_type,
)
self._connect_finalized.set()
else:
await self._connect_finalized.wait()
assert self._socket is not None
await self._socket.wait_for_readiness()
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
await self._socket.sendall(payload)
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
await self._read_semaphore.acquire()
#: There we want to verify if another thread got a response that belong to this thread.
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._pending.remove(query)
self._unconsumed.remove(dns_resp)
self._read_semaphore.release()
continue
try:
data_in_or_segments = await self._socket.recv(1500)
if isinstance(data_in_or_segments, list):
payloads = data_in_or_segments
elif data_in_or_segments:
payloads = [data_in_or_segments]
else:
payloads = []
if self._rfc1035_prefix_mandated is True and payloads:
payload = b"".join(payloads)
while rfc1035_should_read(payload):
extra = await self._socket.recv(1500)
if isinstance(extra, list):
payload += b"".join(extra)
else:
payload += extra
payloads = [payload]
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
self._read_semaphore.release()
if not payloads:
self._terminated = True
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
)
pending_raw_identifiers = [_.raw_id for _ in self._pending]
for payload in payloads:
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
if self._rfc1035_prefix_mandated is True:
fragments = rfc1035_unpack(payload)
else:
fragments = packet_fragment(payload, *pending_raw_identifiers)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ....util import parse_url
from ..factories import ResolverDescription
from ..protocols import ProtocolResolver
from .protocols import AsyncBaseResolver
class AsyncResolverFactory(metaclass=ABCMeta):
@staticmethod
def has(
protocol: ProtocolResolver,
specifier: str | None = None,
implementation: str | None = None,
) -> bool:
package_name: str = __name__.split(".")[0]
module_expr = f".{protocol.value.replace('-', '_')}"
if implementation:
module_expr += f"._{implementation.replace('-', '_').lower()}"
try:
resolver_module = importlib.import_module(
module_expr, f"{package_name}.contrib.resolver._async"
)
except ImportError:
return False
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
resolver_module,
lambda e: isinstance(e, type)
and issubclass(e, AsyncBaseResolver)
and (
(specifier is None and e.specifier is None) or specifier == e.specifier
),
)
if not implementations:
return False
return True
@staticmethod
def new(
protocol: ProtocolResolver,
specifier: str | None = None,
implementation: str | None = None,
**kwargs: Any,
) -> AsyncBaseResolver:
package_name: str = __name__.split(".")[0]
module_expr = f".{protocol.value.replace('-', '_')}"
if implementation:
module_expr += f"._{implementation.replace('-', '_').lower()}"
spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
try:
resolver_module = importlib.import_module(
module_expr, f"{package_name}.contrib.resolver._async"
)
except ImportError as e:
raise NotImplementedError(
f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
) from e
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
resolver_module,
lambda e: isinstance(e, type)
and issubclass(e, AsyncBaseResolver)
and (
(specifier is None and e.specifier is None) or specifier == e.specifier
)
and hasattr(e, "protocol")
and e.protocol == protocol,
)
if not implementations:
raise NotImplementedError(
f"{protocol}{spe_msg}cannot be loaded. "
"No compatible implementation available. "
"Make sure your implementation inherit from BaseResolver."
)
implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]
return implementation_target(**kwargs)
class AsyncResolverDescription(ResolverDescription):
"""Describe how a BaseResolver must be instantiated."""
def new(self) -> AsyncBaseResolver:
kwargs = {**self.kwargs}
if self.server:
kwargs["server"] = self.server
if self.port:
kwargs["port"] = self.port
if self.host_patterns:
kwargs["patterns"] = self.host_patterns
return AsyncResolverFactory.new(
self.protocol,
self.specifier,
self.implementation,
**kwargs,
)
@staticmethod
def from_url(url: str) -> AsyncResolverDescription:
parsed_url = parse_url(url)
schema = parsed_url.scheme
if schema is None:
raise ValueError("Given DNS url is missing a protocol")
specifier = None
implementation = None
if "+" in schema:
schema, specifier = tuple(schema.lower().split("+", 1))
protocol = ProtocolResolver(schema)
kwargs: dict[str, typing.Any] = {}
if parsed_url.path:
kwargs["path"] = parsed_url.path
if parsed_url.auth:
kwargs["headers"] = dict()
if ":" in parsed_url.auth:
username, password = parsed_url.auth.split(":")
username = username.strip("'\"")
password = password.strip("'\"")
kwargs["headers"]["Authorization"] = (
f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
)
else:
kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
if parsed_url.query:
parameters = parse_qs(parsed_url.query)
for parameter in parameters:
if not parameters[parameter]:
continue
parameter_insensible = parameter.lower()
if (
isinstance(parameters[parameter], list)
and len(parameters[parameter]) > 1
):
if parameter == "implementation":
raise ValueError("Only one implementation can be passed to URL")
values = []
for e in parameters[parameter]:
if "," in e:
values.extend(e.split(","))
else:
values.append(e)
if parameter_insensible in kwargs:
if isinstance(kwargs[parameter_insensible], list):
kwargs[parameter_insensible].extend(values)
else:
values.append(kwargs[parameter_insensible])
kwargs[parameter_insensible] = values
continue
kwargs[parameter_insensible] = values
continue
value: str = parameters[parameter][0].lower().strip(" ")
if parameter == "implementation":
implementation = value
continue
if "," in value:
list_of_values = value.split(",")
if parameter_insensible in kwargs:
if isinstance(kwargs[parameter_insensible], list):
kwargs[parameter_insensible].extend(list_of_values)
else:
list_of_values.append(kwargs[parameter_insensible])
continue
kwargs[parameter_insensible] = list_of_values
continue
value_converted: bool | int | float | None = None
if value in ["false", "true"]:
value_converted = True if value == "true" else False
elif value.isdigit():
value_converted = int(value)
elif (
value.count(".") == 1
and value.index(".") > 0
and value.replace(".", "").isdigit()
):
value_converted = float(value)
kwargs[parameter_insensible] = (
value if value_converted is None else value_converted
)
host_patterns: list[str] = []
if "hosts" in kwargs:
host_patterns = (
kwargs["hosts"].split(",")
if isinstance(kwargs["hosts"], str)
else kwargs["hosts"]
)
del kwargs["hosts"]
return AsyncResolverDescription(
protocol,
specifier,
implementation,
parsed_url.host,
parsed_url.port,
*host_patterns,
**kwargs,
)

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
import socket
import typing
from .....util.url import _IPV6_ADDRZ_RE
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class InMemoryResolver(AsyncBaseResolver):
protocol = ProtocolResolver.MANUAL
implementation = "dict"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
if self._host_patterns:
for record in self._host_patterns:
if ":" not in record:
continue
hostname, addr = record.split(":", 1)
self.register(hostname, addr)
self._host_patterns = tuple([])
def recycle(self) -> AsyncBaseResolver:
return self
async def close(self) -> None: # type: ignore[override]
pass # no-op
def is_available(self) -> bool:
return True
def have_constraints(self) -> bool:
return True
def support(self, hostname: str | bytes | None) -> bool | None:
if hostname is None:
hostname = "localhost"
if isinstance(hostname, bytes):
hostname = hostname.decode("ascii")
return hostname in self._hosts
def register(self, hostname: str, ipaddr: str) -> None:
if hostname not in self._hosts:
self._hosts[hostname] = []
else:
for e in self._hosts[hostname]:
t, addr = e
if addr in ipaddr:
return
if _IPV6_ADDRZ_RE.match(ipaddr):
self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
elif is_ipv6(ipaddr):
self._hosts[hostname].append((socket.AF_INET6, ipaddr))
else:
self._hosts[hostname].append((socket.AF_INET, ipaddr))
if len(self._hosts) > self._maxsize:
k = None
for k in self._hosts.keys():
break
if k:
self._hosts.pop(k)
def clear(self, hostname: str) -> None:
if hostname in self._hosts:
del self._hosts[hostname]
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
host = "localhost" # Defensive: stdlib cpy behavior
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
if host not in self._hosts:
raise socket.gaierror(f"no records found for hostname {host} in-memory")
for entry in self._hosts[host]:
addr_type, addr_target = entry
if family != socket.AF_UNSPEC:
if family != addr_type:
continue
results.append(
(
addr_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
(addr_target, port)
if addr_type == socket.AF_INET
else (addr_target, port, 0, 0),
)
)
if not results:
raise socket.gaierror(f"no records found for hostname {host} in-memory")
return results

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import socket
import typing
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class NullResolver(AsyncBaseResolver):
protocol = ProtocolResolver.NULL
implementation = "dummy"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
def recycle(self) -> AsyncBaseResolver:
return self
async def close(self) -> None: # type: ignore[override]
pass # no-op
def is_available(self) -> bool:
return True
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
host = "localhost" # Defensive: stdlib cpy behavior
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,375 @@
from __future__ import annotations
import asyncio
import ipaddress
import socket
import struct
import sys
import typing
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta, timezone
from ...._constant import UDP_LINUX_GRO
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ....exceptions import LocationParseError
from ....util.connection import _set_socket_options, allowed_gai_family
from ....util.timeout import _DEFAULT_TIMEOUT
from ...ssa import AsyncSocket
from ...ssa._timeout import timeout as timeout_
from ..protocols import BaseResolver
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
def recycle(self) -> AsyncBaseResolver:
return super().recycle() # type: ignore[return-value]
@abstractmethod
async def close(self) -> None: # type: ignore[override]
"""Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
raise NotImplementedError
@abstractmethod
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
"""This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
raise NotImplementedError
# This function is copied from socket.py in the Python 2.7 standard
# library test suite. Added to its signature is only `socket_options`.
# One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality.
async def create_connection( # type: ignore[override]
self,
address: tuple[str, int],
timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
source_address: tuple[str, int] | None = None,
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
socket_kind: socket.SocketKind = socket.SOCK_STREAM,
*,
quic_upgrade_via_dns_rr: bool = False,
timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
| None = None,
default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
) -> AsyncSocket:
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`socket.getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection.
An host of '' or port 0 tells the OS to use the default.
"""
host, port = address
if host.startswith("["):
host = host.strip("[]")
err = None
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
# us select whether to work with IPv4 DNS records, IPv6 records, or both.
# The original create_connection function always returns all records.
family = allowed_gai_family()
if family != socket.AF_UNSPEC:
default_socket_family = family
if source_address is not None:
if isinstance(
ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
):
default_socket_family = socket.AF_INET
else:
default_socket_family = socket.AF_INET6
try:
host.encode("idna")
except UnicodeError:
raise LocationParseError(f"'{host}', label empty or too long") from None
dt_pre_resolve = datetime.now(tz=timezone.utc)
if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
# we can hang here in case of bad networking conditions
# the DNS may never answer or the packets can be lost.
# this isn't possible in sync mode. unfortunately.
# found by user at https://github.com/jawah/niquests/issues/183
# todo: find a way to limit getaddrinfo delays in sync mode.
try:
async with timeout_(timeout):
records = await self.getaddrinfo(
host,
port,
default_socket_family,
socket_kind,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
except TimeoutError:
raise socket.gaierror(
f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
)
else:
records = await self.getaddrinfo(
host,
port,
default_socket_family,
socket_kind,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
dt_pre_established = datetime.now(tz=timezone.utc)
for res in records:
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = AsyncSocket(af, socktype, proto)
# we need to add this or reusing the same origin port will likely fail within
# short period of time. kernel put port on wait shut.
if source_address:
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except (OSError, AttributeError): # Defensive: very old OS?
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except (
OSError,
AttributeError,
): # Defensive: we can't do anything better than this.
pass
try:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
)
except (OSError, AttributeError):
pass
# attempt to leverage GRO when under Linux
if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
try:
sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
except OSError: # Defensive: oh, well(...) anyway!
pass
# If provided, set socket level options before connecting.
_set_socket_options(sock, socket_options)
if timeout is not _DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
try:
await sock.connect(sa)
except asyncio.CancelledError:
sock.close()
raise
# Break explicitly a reference cycle
err = None
delta_post_established = (
datetime.now(tz=timezone.utc) - dt_pre_established
)
if timing_hook is not None:
timing_hook(
(
delta_post_resolve,
delta_post_established,
datetime.now(tz=timezone.utc),
)
)
return sock
except (OSError, OverflowError) as _:
err = _
if sock is not None:
sock.close()
if isinstance(_, OverflowError):
break
if err is not None:
try:
raise err
finally:
# Break explicitly a reference cycle
err = None
else:
raise OSError("getaddrinfo returns an empty list")
class AsyncManyResolver(AsyncBaseResolver):
"""
Special resolver that use many child resolver. Priorities
are based on given order (list of BaseResolver).
"""
def __init__(self, *resolvers: AsyncBaseResolver) -> None:
super().__init__(None, None)
self._size = len(resolvers)
self._unconstrained: list[AsyncBaseResolver] = [
_ for _ in resolvers if not _.have_constraints()
]
self._constrained: list[AsyncBaseResolver] = [
_ for _ in resolvers if _.have_constraints()
]
self._concurrent: int = 0
self._terminated: bool = False
def recycle(self) -> AsyncBaseResolver:
resolvers = []
for resolver in self._unconstrained + self._constrained:
resolvers.append(resolver.recycle())
return AsyncManyResolver(*resolvers)
async def close(self) -> None: # type: ignore[override]
for resolver in self._unconstrained + self._constrained:
await resolver.close()
self._terminated = True
def is_available(self) -> bool:
return not self._terminated
def __resolvers(
self, constrained: bool = False
) -> typing.Generator[AsyncBaseResolver, None, None]:
resolvers = self._unconstrained if not constrained else self._constrained
if not resolvers:
return
with self._lock:
self._concurrent += 1
try:
resolver_count = len(resolvers)
start_idx = (self._concurrent - 1) % resolver_count
for idx in range(start_idx, resolver_count):
if not resolvers[idx].is_available():
with self._lock:
resolvers[idx] = resolvers[idx].recycle()
yield resolvers[idx]
if start_idx > 0:
for idx in range(0, start_idx):
if not resolvers[idx].is_available():
with self._lock:
resolvers[idx] = resolvers[idx].recycle()
yield resolvers[idx]
finally:
with self._lock:
self._concurrent -= 1
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if isinstance(host, bytes):
host = host.decode("ascii")
if host is None:
host = "localhost"
tested_resolvers = []
any_constrained_tried: bool = False
for resolver in self.__resolvers(True):
can_resolve = resolver.support(host)
if can_resolve is True:
any_constrained_tried = True
try:
results = await resolver.getaddrinfo(
host,
port,
family,
type,
proto,
flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
if results:
return results
except socket.gaierror as exc:
if isinstance(exc.args[0], str) and (
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
):
raise
continue
elif can_resolve is False:
tested_resolvers.append(resolver)
if any_constrained_tried:
raise socket.gaierror(
f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
)
for resolver in self.__resolvers():
try:
results = await resolver.getaddrinfo(
host,
port,
family,
type,
proto,
flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
if results:
return results
except socket.gaierror as exc:
if isinstance(exc.args[0], str) and (
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
):
raise
continue
raise socket.gaierror(
f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
)

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
import socket
import typing
from ...protocols import ProtocolResolver
from ..protocols import AsyncBaseResolver
class SystemResolver(AsyncBaseResolver):
implementation = "socket"
protocol = ProtocolResolver.SYSTEM
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
def support(self, hostname: str | bytes | None) -> bool | None:
if hostname is None:
return True
if isinstance(hostname, bytes):
hostname = hostname.decode("ascii")
if hostname == "localhost":
return True
return super().support(hostname)
def recycle(self) -> AsyncBaseResolver:
return self
async def close(self) -> None: # type: ignore[override]
pass # no-op!
def is_available(self) -> bool:
return True
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
return await asyncio.get_running_loop().getaddrinfo(
host=host,
port=port,
family=family,
type=type,
proto=proto,
flags=flags,
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,641 @@
from __future__ import annotations
import socket
import typing
from base64 import b64encode
from ...._collections import HTTPHeaderDict
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
from ....connectionpool import HTTPSConnectionPool
from ....response import HTTPResponse
from ....util.url import parse_url
from ..protocols import (
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
class HTTPSResolver(BaseResolver):
"""
Advanced DNS over HTTPS resolver.
No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
also implemented at Cloudflare.
Support RFC 8484 without JSON. Disabled by default.
"""
implementation = "urllib3"
protocol = ProtocolResolver.DOH
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port or 443, *patterns, **kwargs)
self._path: str = "/resolve"
if "path" in kwargs:
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
self._path = kwargs["path"]
kwargs.pop("path")
self._rfc8484: bool = False
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
self._rfc8484 = True
kwargs.pop("rfc8484")
assert self._server is not None
if "source_address" in kwargs:
if isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
if bind_ip and bind_port.isdigit():
kwargs["source_address"] = (
bind_ip,
int(bind_port),
)
else:
raise ValueError("invalid source_address given in parameters")
else:
raise ValueError("invalid source_address given in parameters")
if "proxy" in kwargs:
kwargs["_proxy"] = parse_url(kwargs["proxy"])
kwargs.pop("proxy")
if "maxsize" not in kwargs:
kwargs["maxsize"] = 10
if "proxy_headers" in kwargs and "_proxy" in kwargs:
proxy_headers = HTTPHeaderDict()
if not isinstance(kwargs["proxy_headers"], list):
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
for item in kwargs["proxy_headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
proxy_headers.add(k, v)
kwargs["_proxy_headers"] = proxy_headers
if "headers" in kwargs:
headers = HTTPHeaderDict()
if not isinstance(kwargs["headers"], list):
kwargs["headers"] = [kwargs["headers"]]
for item in kwargs["headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
headers.add(k, v)
kwargs["headers"] = headers
if "disabled_svn" in kwargs:
if not isinstance(kwargs["disabled_svn"], list):
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
disabled_svn = set()
for svn in kwargs["disabled_svn"]:
svn = svn.lower()
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
kwargs["disabled_svn"] = disabled_svn
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
self._connection_callback: (
typing.Callable[[ConnectionInfo], None] | None
) = kwargs["on_post_connection"]
kwargs.pop("on_post_connection")
else:
self._connection_callback = None
self._pool = HTTPSConnectionPool(self._server, self._port, **kwargs)
def close(self) -> None:
self._pool.close()
def is_available(self) -> bool:
return self._pool.pool is not None
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a HTTPSResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror("Address family for hostname not supported")
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror("Address family for hostname not supported")
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
promises: list[HTTPResponse | ResponsePromise] = []
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
if family in [socket.AF_UNSPEC, socket.AF_INET]:
if not self._rfc8484:
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "1"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.A, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
if not self._rfc8484:
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "28"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.AAAA, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if quic_upgrade_via_dns_rr:
if not self._rfc8484:
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "65"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.HTTPS, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
responses: list[HTTPResponse] = []
for promise in promises:
if isinstance(promise, HTTPResponse):
responses.append(promise)
continue
responses.append(self._pool.get_response(promise=promise)) # type: ignore[arg-type]
results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
for response in responses:
if response.status >= 300:
raise socket.gaierror(
f"DNS over HTTPS was unsuccessful, server response status {response.status}."
)
if not self._rfc8484:
payload = response.json()
assert "Status" in payload and isinstance(payload["Status"], int)
if payload["Status"] != 0:
msg = (
payload["Comment"]
if "Comment" in payload
else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
)
if isinstance(msg, list):
msg = ", ".join(msg)
raise socket.gaierror(msg)
assert "Question" in payload and isinstance(payload["Question"], list)
if "Answer" not in payload:
continue
assert isinstance(payload["Answer"], list)
for answer in payload["Answer"]:
if answer["type"] not in [1, 28, 65]:
continue
assert "data" in answer
assert isinstance(answer["data"], str)
# DNS RR/HTTPS
if answer["type"] == 65:
# "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
# or..
# "1 . alpn=h2,h3"
rr: str = answer["data"]
if rr.startswith("\\#"): # it means, raw, bytes.
rr = "".join(rr[2:].split(" ")[2:])
try:
raw_record = bytes.fromhex(rr)
except ValueError:
raw_record = b""
https_record = parse_https_rdata(raw_record)
if "h3" not in https_record["alpn"]:
continue
remote_preemptive_quic_rr = True
else:
rr_decode: dict[str, str] = dict(
tuple(_.lower().split("=", 1)) # type: ignore[misc]
for _ in rr.split(" ")
if "=" in _
)
if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
continue
remote_preemptive_quic_rr = True
if "ipv4hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET,
]:
for ipv4 in rr_decode["ipv4hint"].split(","):
results.append(
(
socket.AF_INET,
socket.SOCK_DGRAM,
17,
"",
(
ipv4,
port,
),
)
)
if "ipv6hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET6,
]:
for ipv6 in rr_decode["ipv6hint"].split(","):
results.append(
(
socket.AF_INET6,
socket.SOCK_DGRAM,
17,
"",
(
ipv6,
port,
0,
0,
),
)
)
continue
inet_type = (
socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
answer["data"],
port,
)
if inet_type == socket.AF_INET
else (
answer["data"],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
else:
dns_resp = DomainNameServerReturn(response.data)
for record in dns_resp.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
kwargs["path"] = "/dns-query"
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query"})
super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
try:
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
QUICResolver = None # type: ignore
AdGuardResolver = None # type: ignore
NextDNSResolver = None # type: ignore
__all__ = (
"QUICResolver",
"AdGuardResolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,541 @@
from __future__ import annotations
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from ....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...ssa._gro import _sock_has_gro, _sock_has_gso, sync_recv_gro, sync_sendmsg_gso
from ..dou import PlainResolver
from ..protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
if IS_FIPS:
raise ImportError(
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
)
class QUICResolver(PlainResolver):
protocol = ProtocolResolver.DOQ
implementation = "qh3"
def __init__(
self,
server: str,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
):
super().__init__(server, port or 853, *patterns, **kwargs)
# qh3 load_default_certs seems off. need to investigate.
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
kwargs["ca_cert_data"] = []
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
try:
ctx.load_default_certs()
for der in ctx.get_ca_certs(binary_form=True):
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
if kwargs["ca_cert_data"]:
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
else:
del kwargs["ca_cert_data"]
except (AttributeError, ValueError, OSError):
del kwargs["ca_cert_data"]
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
if (
"cert_reqs" not in kwargs
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
):
raise ssl.SSLError(
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
"Add ?cert_reqs=0 to disable certificate checks."
)
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=["doq"],
server_name=self._server
if "server_hostname" not in kwargs
else kwargs["server_hostname"],
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
if "cert_reqs" in kwargs
else ssl.CERT_REQUIRED,
cadata=kwargs["ca_cert_data"].encode()
if "ca_cert_data" in kwargs
else None,
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
idle_timeout=300.0,
)
if "cert_file" in kwargs:
configuration.load_cert_chain(
kwargs["cert_file"],
kwargs["key_file"] if "key_file" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
elif "cert_data" in kwargs:
configuration.load_cert_chain(
kwargs["cert_data"],
kwargs["key_data"] if "key_data" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
self._quic = QuicConnection(configuration=configuration)
self._dgram_gro_enabled: bool = _sock_has_gro(self._socket)
self._dgram_gso_enabled: bool = _sock_has_gso(self._socket)
self._quic.connect((self._server, self._port), monotonic())
self.__exchange_until(HandshakeCompleted, receive_first=False)
self._terminated: bool = False
self._should_disconnect: bool = False
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
self._rfc1035_prefix_mandated = True
self._unconsumed: deque[DomainNameServerReturn] = deque()
self._pending: deque[DomainNameServerQuery] = deque()
def close(self) -> None:
if not self._terminated:
with self._lock:
self._quic.close()
while True:
datagrams = self._quic.datagrams_to_send(monotonic())
if not datagrams:
break
if self._dgram_gso_enabled and len(datagrams) > 1:
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
else:
for datagram in datagrams:
self._socket.sendall(datagram[0])
self._socket.close()
self._terminated = True
def is_available(self) -> bool:
self._quic.handle_timer(monotonic())
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
self._terminated = True
return not self._terminated
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' using the QUICResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
open_streams = []
with self._lock:
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
stream_id = self._quic.get_next_available_stream_id()
self._quic.send_stream_data(stream_id, payload, True)
open_streams.append(stream_id)
datagrams = self._quic.datagrams_to_send(monotonic())
if self._dgram_gso_enabled and len(datagrams) > 1:
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
else:
for dg in datagrams:
self._socket.sendall(dg[0])
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
with self._lock:
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._unconsumed.remove(dns_resp)
self._pending.remove(query)
continue
try:
events: list[StreamDataReceived] = self.__exchange_until( # type: ignore[assignment]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
payload = b"".join([e.data for e in events])
while rfc1035_should_read(payload):
events.extend(
self.__exchange_until( # type: ignore[arg-type]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
)
payload = b"".join([e.data for e in events])
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
if not payload:
continue
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
fragments = rfc1035_unpack(payload)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
if self._should_disconnect:
with self._lock:
self.close()
self._should_disconnect = False
self._terminated = True
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
def __exchange_until(
self,
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
*,
receive_first: bool = False,
event_type_collectable: type[QuicEvent]
| tuple[type[QuicEvent], ...]
| None = None,
respect_end_stream_signal: bool = True,
) -> list[QuicEvent]:
while True:
if receive_first is False:
now = monotonic()
while True:
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
if self._dgram_gso_enabled and len(datagrams) > 1:
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
else:
for datagram in datagrams:
self._socket.sendall(datagram[0])
events = []
while True:
if not self._quic._events:
if self._dgram_gro_enabled:
data_in = sync_recv_gro(self._socket, 65535)
else:
data_in = self._socket.recv(1500)
if not data_in:
break
now = monotonic()
if isinstance(data_in, list):
for gro_segment in data_in:
self._quic.receive_datagram(
gro_segment, (self._server, self._port), now
)
else:
self._quic.receive_datagram(
data_in, (self._server, self._port), now
)
while True:
now = monotonic()
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
if self._dgram_gso_enabled and len(datagrams) > 1:
sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
else:
for datagram in datagrams:
self._socket.sendall(datagram[0])
for ev in iter(self._quic.next_event, None):
if isinstance(ev, ConnectionTerminated):
if ev.error_code == 298:
raise SSLError(
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
)
raise socket.gaierror(
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
)
elif isinstance(ev, StreamReset):
self._terminated = True
raise socket.gaierror(
"DNS over QUIC server submitted a StreamReset. A request was rejected."
)
elif isinstance(ev, StopSendingReceived):
self._should_disconnect = True
continue
if event_type_collectable:
if isinstance(ev, event_type_collectable):
events.append(ev)
else:
events.append(ev)
if isinstance(ev, event_type):
if not respect_end_stream_signal:
return events
if hasattr(ev, "stream_ended") and ev.stream_ended:
return events
elif hasattr(ev, "stream_ended") is False:
return events
return events
class AdGuardResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import socket
import typing
from ....util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
from ..dou import PlainResolver
from ..protocols import ProtocolResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
"""
Basic DNS resolver over TLS.
Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
"""
protocol = ProtocolResolver.DOT
implementation = "ssl"
def __init__(
self,
server: str,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
if "timeout" in kwargs and isinstance(kwargs["timeout"], (int, float)):
timeout = kwargs["timeout"]
else:
timeout = None
if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
else:
bind_ip, bind_port = "0.0.0.0", "0"
self._socket = SystemResolver().create_connection(
(server, port or 853),
timeout=timeout,
source_address=(bind_ip, int(bind_port))
if bind_ip != "0.0.0.0" or bind_port != "0"
else None,
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
socket_kind=socket.SOCK_STREAM,
)
super().__init__(server, port, *patterns, **kwargs)
self._socket = ssl_wrap_socket(
self._socket,
server_hostname=server
if "server_hostname" not in kwargs
else kwargs["server_hostname"],
keyfile=kwargs["key_file"] if "key_file" in kwargs else None,
certfile=kwargs["cert_file"] if "cert_file" in kwargs else None,
cert_reqs=resolve_cert_reqs(kwargs["cert_reqs"])
if "cert_reqs" in kwargs
else None,
ca_certs=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
ssl_version=kwargs["ssl_version"] if "ssl_version" in kwargs else None,
ciphers=kwargs["ciphers"] if "ciphers" in kwargs else None,
ca_cert_dir=kwargs["ca_cert_dir"] if "ca_cert_dir" in kwargs else None,
key_password=kwargs["key_password"] if "key_password" in kwargs else None,
ca_cert_data=kwargs["ca_cert_data"] if "ca_cert_data" in kwargs else None,
certdata=kwargs["cert_data"] if "cert_data" in kwargs else None,
keydata=kwargs["key_data"] if "key_data" in kwargs else None,
)
# DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
self._rfc1035_prefix_mandated = True
class GoogleResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,415 @@
from __future__ import annotations
import socket
import typing
from collections import deque
from ...ssa._gro import _sock_has_gro, sync_recv_gro
from ..protocols import (
COMMON_RCODE_LABEL,
BaseResolver,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ..system import SystemResolver
from ..utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
class PlainResolver(BaseResolver):
"""
Minimalist DNS resolver over UDP
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
EDNS is not supported, yet. But we plan to. Willing to contribute?
"""
protocol = ProtocolResolver.DOU
implementation = "socket"
def __init__(
self,
server: str,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port, *patterns, **kwargs)
if not hasattr(self, "_socket"):
if "timeout" in kwargs and isinstance(
kwargs["timeout"],
(
float,
int,
),
):
timeout = kwargs["timeout"]
else:
timeout = None
if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
else:
bind_ip, bind_port = "0.0.0.0", "0"
self._socket = SystemResolver().create_connection(
(server, port or 53),
timeout=timeout,
source_address=(bind_ip, int(bind_port))
if bind_ip != "0.0.0.0" or bind_port != "0"
else None,
socket_options=None,
socket_kind=socket.SOCK_DGRAM,
)
#: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
self._rfc1035_prefix_mandated: bool = False
self._gro_enabled: bool = _sock_has_gro(self._socket)
self._unconsumed: deque[DomainNameServerReturn] = deque()
self._pending: deque[DomainNameServerQuery] = deque()
self._terminated: bool = False
def close(self) -> None:
if not self._terminated:
with self._lock:
if self._socket is not None:
self._socket.shutdown(0)
self._socket.close()
self._terminated = True
def is_available(self) -> bool:
return not self._terminated
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a PlainResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
with self._lock:
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
self._socket.sendall(payload)
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
with self._lock:
#: There we want to verify if another thread got a response that belong to this thread.
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._pending.remove(query)
self._unconsumed.remove(dns_resp)
continue
try:
if self._gro_enabled:
data_in_or_segments = sync_recv_gro(self._socket, 65535)
else:
data_in_or_segments = self._socket.recv(1500)
if isinstance(data_in_or_segments, list):
payloads = data_in_or_segments
elif data_in_or_segments:
payloads = [data_in_or_segments]
else:
payloads = []
if self._rfc1035_prefix_mandated is True and payloads:
payload = b"".join(payloads)
while rfc1035_should_read(payload):
extra = self._socket.recv(1500)
if isinstance(extra, list):
payload += b"".join(extra)
else:
payload += extra
payloads = [payload]
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
if not payloads:
self._terminated = True
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
)
pending_raw_identifiers = [_.raw_id for _ in self._pending]
for payload in payloads:
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
if self._rfc1035_prefix_mandated is True:
fragments = rfc1035_unpack(payload)
else:
fragments = packet_fragment(payload, *pending_raw_identifiers)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
PlainResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,230 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ...util import parse_url
from .protocols import BaseResolver, ProtocolResolver
class ResolverFactory(metaclass=ABCMeta):
@staticmethod
def new(
protocol: ProtocolResolver,
specifier: str | None = None,
implementation: str | None = None,
**kwargs: Any,
) -> BaseResolver:
package_name: str = __name__.split(".")[0]
module_expr = f".{protocol.value.replace('-', '_')}"
if implementation:
module_expr += f"._{implementation.replace('-', '_').lower()}"
spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
try:
resolver_module = importlib.import_module(
module_expr, f"{package_name}.contrib.resolver"
)
except ImportError as e:
raise NotImplementedError(
f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
) from e
implementations: list[tuple[str, type[BaseResolver]]] = inspect.getmembers(
resolver_module,
lambda e: isinstance(e, type)
and issubclass(e, BaseResolver)
and (
(specifier is None and e.specifier is None) or specifier == e.specifier
),
)
if not implementations:
raise NotImplementedError(
f"{protocol}{spe_msg}cannot be loaded. "
"No compatible implementation available. "
"Make sure your implementation inherit from BaseResolver."
)
implementation_target: type[BaseResolver] = implementations.pop()[1]
return implementation_target(**kwargs)
class ResolverDescription:
"""Describe how a BaseResolver must be instantiated."""
def __init__(
self,
protocol: ProtocolResolver,
specifier: str | None = None,
implementation: str | None = None,
server: str | None = None,
port: int | None = None,
*host_patterns: str,
**kwargs: typing.Any,
) -> None:
self.protocol = protocol
self.specifier = specifier
self.implementation = implementation
self.server = server
self.port = port
self.host_patterns = host_patterns
self.kwargs = kwargs
def __setitem__(self, key: str, value: typing.Any) -> None:
self.kwargs[key] = value
def __contains__(self, item: str) -> bool:
return item in self.kwargs
def new(self) -> BaseResolver:
kwargs = {**self.kwargs}
if self.server:
kwargs["server"] = self.server
if self.port:
kwargs["port"] = self.port
if self.host_patterns:
kwargs["patterns"] = self.host_patterns
return ResolverFactory.new(
self.protocol,
self.specifier,
self.implementation,
**kwargs,
)
@staticmethod
def from_url(url: str) -> ResolverDescription:
parsed_url = parse_url(url)
schema = parsed_url.scheme
if schema is None:
raise ValueError("Given DNS url is missing a protocol")
specifier = None
implementation = None
if "+" in schema:
schema, specifier = tuple(schema.lower().split("+", 1))
protocol = ProtocolResolver(schema)
kwargs: dict[str, typing.Any] = {}
if parsed_url.path:
kwargs["path"] = parsed_url.path
if parsed_url.auth:
kwargs["headers"] = dict()
if ":" in parsed_url.auth:
username, password = parsed_url.auth.split(":")
username = username.strip("'\"")
password = password.strip("'\"")
kwargs["headers"]["Authorization"] = (
f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
)
else:
kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
if parsed_url.query:
parameters = parse_qs(parsed_url.query)
for parameter in parameters:
if not parameters[parameter]:
continue
parameter_insensible = parameter.lower()
if (
isinstance(parameters[parameter], list)
and len(parameters[parameter]) > 1
):
if parameter == "implementation":
raise ValueError("Only one implementation can be passed to URL")
values = []
for e in parameters[parameter]:
if "," in e:
values.extend(e.split(","))
else:
values.append(e)
if parameter_insensible in kwargs:
if isinstance(kwargs[parameter_insensible], list):
kwargs[parameter_insensible].extend(values)
else:
values.append(kwargs[parameter_insensible])
kwargs[parameter_insensible] = values
continue
kwargs[parameter_insensible] = values
continue
value: str = parameters[parameter][0].lower().strip(" ")
if parameter == "implementation":
implementation = value
continue
if "," in value:
list_of_values = value.split(",")
if parameter_insensible in kwargs:
if isinstance(kwargs[parameter_insensible], list):
kwargs[parameter_insensible].extend(list_of_values)
else:
list_of_values.append(kwargs[parameter_insensible])
continue
kwargs[parameter_insensible] = list_of_values
continue
value_converted: bool | int | float | None = None
if value in ["false", "true"]:
value_converted = True if value == "true" else False
elif value.isdigit():
value_converted = int(value)
elif (
value.count(".") == 1
and value.index(".") > 0
and value.replace(".", "").isdigit()
):
value_converted = float(value)
kwargs[parameter_insensible] = (
value if value_converted is None else value_converted
)
host_patterns: list[str] = []
if "hosts" in kwargs:
host_patterns = (
kwargs["hosts"].split(",")
if isinstance(kwargs["hosts"], str)
else kwargs["hosts"]
)
del kwargs["hosts"]
return ResolverDescription(
protocol,
specifier,
implementation,
parsed_url.host,
parsed_url.port,
*host_patterns,
**kwargs,
)

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,192 @@
from __future__ import annotations
import socket
import typing
from ....util.url import _IPV6_ADDRZ_RE
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class InMemoryResolver(BaseResolver):
protocol = ProtocolResolver.MANUAL
implementation = "dict"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
if self._host_patterns:
for record in self._host_patterns:
if ":" not in record:
continue
hostname, addr = record.split(":", 1)
self.register(hostname, addr)
self._host_patterns = tuple([])
# probably about our happy eyeballs impl (sync only)
if len(self._hosts) == 1 and len(self._hosts[list(self._hosts.keys())[0]]) == 1:
self._unsafe_expose = True
def recycle(self) -> BaseResolver:
return self
def close(self) -> None:
pass # no-op
def is_available(self) -> bool:
return True
def have_constraints(self) -> bool:
return True
def support(self, hostname: str | bytes | None) -> bool | None:
if hostname is None:
hostname = "localhost"
if isinstance(hostname, bytes):
hostname = hostname.decode("ascii")
return hostname in self._hosts
def register(self, hostname: str, ipaddr: str) -> None:
with self._lock:
if hostname not in self._hosts:
self._hosts[hostname] = []
else:
for e in self._hosts[hostname]:
t, addr = e
if addr in ipaddr:
return
if _IPV6_ADDRZ_RE.match(ipaddr):
self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
elif is_ipv6(ipaddr):
self._hosts[hostname].append((socket.AF_INET6, ipaddr))
else:
self._hosts[hostname].append((socket.AF_INET, ipaddr))
if len(self._hosts) > self._maxsize:
k = None
for k in self._hosts.keys():
break
if k:
self._hosts.pop(k)
def clear(self, hostname: str) -> None:
with self._lock:
if hostname in self._hosts:
del self._hosts[hostname]
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
host = "localhost" # Defensive: stdlib cpy behavior
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
with self._lock:
if host not in self._hosts:
raise socket.gaierror(f"no records found for hostname {host} in-memory")
for entry in self._hosts[host]:
addr_type, addr_target = entry
if family != socket.AF_UNSPEC:
if family != addr_type:
continue
results.append(
(
addr_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
(addr_target, port)
if addr_type == socket.AF_INET
else (addr_target, port, 0, 0),
)
)
if not results:
raise socket.gaierror(f"no records found for hostname {host} in-memory")
return results

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
from ..utils import is_ipv4, is_ipv6
class NullResolver(BaseResolver):
protocol = ProtocolResolver.NULL
implementation = "dummy"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
def recycle(self) -> BaseResolver:
return self
def close(self) -> None:
pass # no-op
def is_available(self) -> bool:
return True
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
host = "localhost" # Defensive: stdlib cpy behavior
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror(
"Servname not supported for ai_socktype"
) # Defensive: stdlib cpy behavior
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror(
"Address family for hostname not supported"
) # Defensive: stdlib cpy behavior
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror(
"Address family for hostname not supported"
) # Defensive: stdlib cpy behavior
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,655 @@
from __future__ import annotations
import ipaddress
import socket
import struct
import sys
import threading
import typing
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from enum import Enum
from random import randint
from ..._constant import UDP_LINUX_GRO
from ..._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ...exceptions import LocationParseError
from ...util.connection import _set_socket_options, allowed_gai_family
from ...util.ssl_match_hostname import CertificateError, match_hostname
from ...util.timeout import _DEFAULT_TIMEOUT
from .utils import inet4_ntoa, inet6_ntoa, parse_https_rdata
if typing.TYPE_CHECKING:
from .utils import HttpsRecord
class ProtocolResolver(str, Enum):
"""
At urllib3.future we aim to propose a wide range of DNS-protocols.
The most used techniques are available.
"""
#: Ask the OS native DNS layer
SYSTEM = "system"
#: DNS over HTTPS
DOH = "doh"
#: DNS over QUIC
DOQ = "doq"
#: DNS over TLS
DOT = "dot"
#: DNS over UDP (insecure)
DOU = "dou"
#: Manual (e.g. hosts)
MANUAL = "in-memory"
#: Void (e.g. purposely disable resolution)
NULL = "null"
#: Custom (e.g. your own implementation, use this when it does not suit any of the protocols specified)
CUSTOM = "custom"
class BaseResolver(metaclass=ABCMeta):
protocol: typing.ClassVar[ProtocolResolver]
specifier: typing.ClassVar[str | None] = None
implementation: typing.ClassVar[str]
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
self._server = server
self._port = port
self._host_patterns: tuple[str, ...] = patterns
self._lock = threading.Lock()
self._kwargs = kwargs
if not self._host_patterns and "patterns" in kwargs:
self._host_patterns = kwargs["patterns"]
# allow to temporarily expose a sock that is "being" created
# this helps with our Happy Eyeballs implementation in sync.
self._unsafe_expose: bool = False
self._sock_cursor: socket.socket | None = None
def recycle(self) -> BaseResolver:
if self.is_available():
raise RuntimeError("Attempting to recycle a Resolver that was not closed")
args = list(self.__class__.__init__.__code__.co_varnames)
args.remove("self")
kwargs_cpy = deepcopy(self._kwargs)
if self._server:
kwargs_cpy["server"] = self._server
if self._port:
kwargs_cpy["port"] = self._port
if "patterns" in args and "kwargs" in args:
return self.__class__(*self._host_patterns, **kwargs_cpy) # type: ignore[arg-type]
elif "kwargs" in args:
return self.__class__(**kwargs_cpy)
return self.__class__() # type: ignore[call-arg]
@property
def server(self) -> str | None:
return self._server
@property
def port(self) -> int | None:
return self._port
def have_constraints(self) -> bool:
return bool(self._host_patterns)
def support(self, hostname: str | bytes | None) -> bool | None:
"""
Determine if given hostname is especially resolvable by given resolver.
If this resolver does not have any constrained list of host, it returns None. Meaning
it support any hostname for resolution.
"""
if not self._host_patterns:
return None
if hostname is None:
hostname = "localhost"
if isinstance(hostname, bytes):
hostname = hostname.decode("ascii")
try:
match_hostname(
{"subjectAltName": (tuple(("DNS", e) for e in self._host_patterns))},
hostname,
)
except CertificateError:
return False
return True
@abstractmethod
def close(self) -> None:
"""Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
raise NotImplementedError
@abstractmethod
def is_available(self) -> bool:
"""Determine if Resolver can receive inquiries."""
raise NotImplementedError
@abstractmethod
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
"""This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
raise NotImplementedError
# This function is copied from socket.py in the Python 2.7 standard
# library test suite. Added to its signature is only `socket_options`.
# One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality.
def create_connection(
self,
address: tuple[str, int],
timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
source_address: tuple[str, int] | None = None,
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
socket_kind: socket.SocketKind = socket.SOCK_STREAM,
*,
quic_upgrade_via_dns_rr: bool = False,
timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
| None = None,
default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
) -> socket.socket:
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`socket.getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection.
An host of '' or port 0 tells the OS to use the default.
"""
host, port = address
if host.startswith("["):
host = host.strip("[]")
err = None
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
# us select whether to work with IPv4 DNS records, IPv6 records, or both.
# The original create_connection function always returns all records.
family = allowed_gai_family()
if family != socket.AF_UNSPEC:
default_socket_family = family
if source_address is not None:
if isinstance(
ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
):
default_socket_family = socket.AF_INET
else:
default_socket_family = socket.AF_INET6
try:
host.encode("idna")
except UnicodeError:
raise LocationParseError(f"'{host}', label empty or too long") from None
dt_pre_resolve = datetime.now(tz=timezone.utc)
records = self.getaddrinfo(
host,
port,
default_socket_family,
socket_kind,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
dt_pre_established = datetime.now(tz=timezone.utc)
for res in records:
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket.socket(af, socktype, proto)
# we need to add this or reusing the same origin port will likely fail within
# short period of time. kernel put port on wait shut.
if source_address is not None:
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except (
OSError,
AttributeError,
): # Defensive: Windows or very old OS?
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except (
OSError,
AttributeError,
): # Defensive: we can't do anything better than this.
pass
try:
sock.setsockopt(
socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
)
except (OSError, AttributeError):
pass
sock.bind(source_address)
# attempt to leverage GRO when under Linux
if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
try:
sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
except OSError: # Defensive: oh, well(...) anyway!
pass
# If provided, set socket level options before connecting.
_set_socket_options(sock, socket_options)
if timeout is not _DEFAULT_TIMEOUT:
sock.settimeout(timeout)
if self._unsafe_expose:
self._sock_cursor = sock
sock.connect(sa)
if self._unsafe_expose:
self._sock_cursor = None
# Break explicitly a reference cycle
err = None
delta_post_established = (
datetime.now(tz=timezone.utc) - dt_pre_established
)
if timing_hook is not None:
timing_hook(
(
delta_post_resolve,
delta_post_established,
datetime.now(tz=timezone.utc),
)
)
return sock
except (OSError, OverflowError) as _:
err = _
if sock is not None:
sock.close()
if isinstance(_, OverflowError):
break
if err is not None:
try:
raise err
finally:
# Break explicitly a reference cycle
err = None
else:
raise OSError("getaddrinfo returns an empty list")
class ManyResolver(BaseResolver):
"""
Special resolver that use many child resolver. Priorities
are based on given order (list of BaseResolver).
"""
def __init__(self, *resolvers: BaseResolver) -> None:
super().__init__(None, None)
self._size = len(resolvers)
self._unconstrained: list[BaseResolver] = [
_ for _ in resolvers if not _.have_constraints()
]
self._constrained: list[BaseResolver] = [
_ for _ in resolvers if _.have_constraints()
]
self._concurrent: int = 0
self._terminated: bool = False
def recycle(self) -> BaseResolver:
resolvers = []
for resolver in self._unconstrained + self._constrained:
resolvers.append(resolver.recycle())
return ManyResolver(*resolvers)
def close(self) -> None:
for resolver in self._unconstrained + self._constrained:
resolver.close()
self._terminated = True
def is_available(self) -> bool:
return not self._terminated
def __resolvers(
self, constrained: bool = False
) -> typing.Generator[BaseResolver, None, None]:
resolvers = self._unconstrained if not constrained else self._constrained
if not resolvers:
return
with self._lock:
self._concurrent += 1
try:
resolver_count = len(resolvers)
start_idx = (self._concurrent - 1) % resolver_count
for idx in range(start_idx, resolver_count):
if not resolvers[idx].is_available():
with self._lock:
resolvers[idx] = resolvers[idx].recycle()
yield resolvers[idx]
if start_idx > 0:
for idx in range(0, start_idx):
if not resolvers[idx].is_available():
with self._lock:
resolvers[idx] = resolvers[idx].recycle()
yield resolvers[idx]
finally:
with self._lock:
self._concurrent -= 1
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if isinstance(host, bytes):
host = host.decode("ascii")
if host is None:
host = "localhost"
tested_resolvers = []
any_constrained_tried: bool = False
for resolver in self.__resolvers(True):
can_resolve = resolver.support(host)
if can_resolve is True:
any_constrained_tried = True
try:
results = resolver.getaddrinfo(
host,
port,
family,
type,
proto,
flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
if results:
return results
except socket.gaierror as exc:
if isinstance(exc.args[0], str) and (
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
):
raise
continue
elif can_resolve is False:
tested_resolvers.append(resolver)
if any_constrained_tried:
raise socket.gaierror(
f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
)
for resolver in self.__resolvers():
try:
results = resolver.getaddrinfo(
host,
port,
family,
type,
proto,
flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
if results:
return results
except socket.gaierror as exc:
if isinstance(exc.args[0], str) and (
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
):
raise
continue
raise socket.gaierror(
f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
)
class SupportedQueryType(int, Enum):
"""
urllib3.future does not need anything else so far. let's be pragmatic.
Each type is associated with its hex value as per the RFC.
"""
A = 0x0001
AAAA = 0x001C
HTTPS = 0x0041
class DomainNameServerQuery:
"""
Minimalist DNS query/message to ask for A, AAAA and HTTPS records.
Only meant for urllib3.future use. Does not cover all of possible extent of use.
"""
def __init__(
self, host: str, query_type: SupportedQueryType, override_id: int | None = None
) -> None:
self._id = struct.pack(
"!H", randint(0x0000, 0xFFFF) if override_id is None else override_id
)
self._host = host
self._query = query_type
self._flags = struct.pack("!H", 0x0100)
self._qd_count = struct.pack("!H", 1)
self._cached: bytes | None = None
@property
def id(self) -> int:
return struct.unpack("!H", self._id)[0] # type: ignore[no-any-return]
@property
def raw_id(self) -> bytes:
return self._id
def __repr__(self) -> str:
return f"<Query '{self._host}' IN {self._query.name}>"
def __bytes__(self) -> bytes:
if self._cached:
return self._cached
payload = b""
payload += self._id
payload += self._flags
payload += self._qd_count
payload += b"\x00\x00"
payload += b"\x00\x00"
payload += b"\x00\x00"
for ext in self._host.split("."):
payload += struct.pack("!B", len(ext))
payload += ext.encode("ascii")
payload += b"\x00"
payload += struct.pack("!H", self._query.value)
payload += struct.pack("!H", 0x0001)
self._cached = payload
return payload
@staticmethod
def bulk(host: str, *types: SupportedQueryType) -> list[DomainNameServerQuery]:
queries = []
for query_type in types:
queries.append(DomainNameServerQuery(host, query_type=query_type))
return queries
#: Most common status code, not exhaustive at all.
COMMON_RCODE_LABEL: dict[int, str] = {
0: "No Error",
1: "Format Error",
2: "Server Failure",
3: "Non-Existent Domain",
5: "Query Refused",
9: "Not Authorized",
}
class DomainNameServerParseException(Exception): ...
class DomainNameServerReturn:
"""
Minimalist DNS response parser. Allow to quickly extract key-data out of it.
Meant for A, AAAA and HTTPS records. Basically only what we need.
"""
def __init__(self, payload: bytes) -> None:
try:
up = struct.unpack("!HHHHHH", payload[:12])
self._id = up[0]
self._flags = up[1]
self._qd_count = up[2]
self._an_count = up[3]
self._rcode = int(f"0x{hex(payload[3])[-1]}", 16)
self._hostname: str = ""
idx = 12
while True:
c = payload[idx]
if c == 0:
idx += 1
break
self._hostname += payload[idx + 1 : idx + 1 + c].decode("ascii") + "."
idx += c + 1
self._records: list[tuple[SupportedQueryType, int, str | HttpsRecord]] = []
if self._an_count:
idx += 4
while idx < len(payload):
up = struct.unpack("!HHHI", payload[idx : idx + 10])
entry_size = struct.unpack("!H", payload[idx + 10 : idx + 12])[0]
data = payload[idx + 12 : idx + 12 + entry_size]
if len(data) == 4:
decoded_data: str | HttpsRecord = inet4_ntoa(data)
elif len(data) == 16:
decoded_data = inet6_ntoa(data)
elif data:
decoded_data = parse_https_rdata(data)
else:
continue
try:
self._records.append(
(SupportedQueryType(up[1]), up[-1], decoded_data)
)
except ValueError:
pass
idx += 12 + entry_size
except (struct.error, IndexError, ValueError, UnicodeDecodeError) as e:
raise DomainNameServerParseException(
"A protocol error occurred while parsing the DNS response payload: "
f"{str(e)}"
) from e
@property
def id(self) -> int:
return self._id # type: ignore[no-any-return]
@property
def hostname(self) -> str:
return self._hostname
@property
def records(self) -> list[tuple[SupportedQueryType, int, str | HttpsRecord]]:
return self._records
@property
def is_found(self) -> bool:
return bool(self._records)
@property
def rcode(self) -> int:
return self._rcode
@property
def is_ok(self) -> bool:
return self._rcode == 0
def __repr__(self) -> str:
if self.is_ok:
return f"<Records '{self.hostname}' {self._records}>"
return f"<DNS Error '{self.hostname}' with Status {self.rcode} ({COMMON_RCODE_LABEL[self.rcode] if self.rcode in COMMON_RCODE_LABEL else 'Unknown'})>"

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import socket
import typing
from ..protocols import BaseResolver, ProtocolResolver
class SystemResolver(BaseResolver):
implementation = "socket"
protocol = ProtocolResolver.SYSTEM
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
kwargs.pop("port")
super().__init__(None, None, *patterns, **kwargs)
def support(self, hostname: str | bytes | None) -> bool | None:
if hostname is None:
return True
if isinstance(hostname, bytes):
hostname = hostname.decode("ascii")
if hostname == "localhost":
return True
return super().support(hostname)
def recycle(self) -> BaseResolver:
return self
def close(self) -> None:
pass # no-op!
def is_available(self) -> bool:
return True
def getaddrinfo(
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
# the | tuple[int, bytes] is silently ignored, can't happen with our cases.
return socket.getaddrinfo( # type: ignore[return-value]
host=host,
port=port,
family=family,
type=type,
proto=proto,
flags=flags,
)

View File

@@ -0,0 +1,322 @@
from __future__ import annotations
import base64
import binascii
import socket
import struct
import typing
if typing.TYPE_CHECKING:
class HttpsRecord(typing.TypedDict):
priority: int
target: str
alpn: list[str]
ipv4hint: list[str]
ipv6hint: list[str]
echconfig: list[str]
def inet4_ntoa(address: bytes) -> str:
"""
Convert an IPv4 address from bytes to str.
"""
if len(address) != 4:
raise ValueError(
f"IPv4 addresses are 4 bytes long, got {len(address)} byte(s) instead"
)
return "%u.%u.%u.%u" % (address[0], address[1], address[2], address[3])
def inet6_ntoa(address: bytes) -> str:
"""
Convert an IPv6 address from bytes to str.
"""
if len(address) != 16:
raise ValueError(
f"IPv6 addresses are 16 bytes long, got {len(address)} byte(s) instead"
)
hex = binascii.hexlify(address)
chunks = []
i = 0
length = len(hex)
while i < length:
chunk = hex[i : i + 4].decode().lstrip("0") or "0"
chunks.append(chunk)
i += 4
# Compress the longest subsequence of 0-value chunks to ::
best_start = 0
best_len = 0
start = -1
last_was_zero = False
for i in range(8):
if chunks[i] != "0":
if last_was_zero:
end = i
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
last_was_zero = False
elif not last_was_zero:
start = i
last_was_zero = True
if last_was_zero:
end = 8
current_len = end - start
if current_len > best_len:
best_start = start
best_len = current_len
if best_len > 1:
if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
# We have an embedded IPv4 address
if best_len == 6:
prefix = "::"
else:
prefix = "::ffff:"
thex = prefix + inet4_ntoa(address[12:])
else:
thex = (
":".join(chunks[:best_start])
+ "::"
+ ":".join(chunks[best_start + best_len :])
)
else:
thex = ":".join(chunks)
return thex
def packet_fragment(payload: bytes, *identifiers: bytes) -> tuple[bytes, ...]:
results = []
offset = 0
start_packet_idx = []
lead_identifier = None
for identifier in identifiers:
idx = payload[:12].find(identifier)
if idx == -1:
continue
if idx != 0:
offset = idx
start_packet_idx.append(idx - offset)
lead_identifier = identifier
break
for identifier in identifiers:
if identifier == lead_identifier:
continue
if offset == 0:
idx = payload.find(b"\x02" + identifier)
else:
idx = payload.find(identifier)
if idx == -1:
continue
start_packet_idx.append(idx - offset)
if not start_packet_idx:
raise ValueError(
"no identifiable dns message emerged from given payload. "
"this should not happen at all. networking issue?"
)
if len(start_packet_idx) == 1:
return (payload,)
start_packet_idx = sorted(start_packet_idx)
previous_idx = None
for idx in start_packet_idx:
if previous_idx is None:
previous_idx = idx
continue
results.append(payload[previous_idx:idx])
previous_idx = idx
results.append(payload[previous_idx:])
return tuple(results)
def is_ipv4(addr: str) -> bool:
try:
socket.inet_aton(addr)
return True
except OSError:
return False
def is_ipv6(addr: str) -> bool:
try:
socket.inet_pton(socket.AF_INET6, addr)
return True
except OSError:
return False
def validate_length_of(hostname: str) -> None:
"""RFC 1035 impose a limit on a domain name length. We verify it there."""
if len(hostname.strip(".")) > 253:
raise UnicodeError("hostname to resolve exceed 253 characters")
elif any([len(_) > 63 for _ in hostname.split(".")]):
raise UnicodeError("at least one label to resolve exceed 63 characters")
def rfc1035_should_read(payload: bytes) -> bool:
if not payload:
return False
if len(payload) <= 2:
return True
cursor = payload
while True:
expected_size: int = struct.unpack("!H", cursor[:2])[0]
if len(cursor[2:]) == expected_size:
return False
elif len(cursor[2:]) < expected_size:
return True
cursor = cursor[2 + expected_size :]
def rfc1035_unpack(payload: bytes) -> tuple[bytes, ...]:
cursor = payload
packets = []
while cursor:
expected_size: int = struct.unpack("!H", cursor[:2])[0]
packets.append(cursor[2 : 2 + expected_size])
cursor = cursor[2 + expected_size :]
return tuple(packets)
def rfc1035_pack(message: bytes) -> bytes:
return struct.pack("!H", len(message)) + message
def read_name(data: bytes, offset: int) -> tuple[str, int]:
"""
Read a DNSencoded name (with compression pointers) from data[offset:].
Returns (name, new_offset).
"""
labels = []
while True:
length = data[offset]
# compression pointer?
if length & 0xC0 == 0xC0:
pointer = struct.unpack_from("!H", data, offset)[0] & 0x3FFF
subname, _ = read_name(data, pointer)
labels.append(subname)
offset += 2
break
if length == 0:
offset += 1
break
offset += 1
labels.append(data[offset : offset + length].decode())
offset += length
return ".".join(labels), offset
def parse_echconfigs(buf: bytes) -> list[str]:
"""
buf is the raw bytes of the ECHConfig vector:
- 2-byte total length, then for each:
- 2-byte cfg length + that many bytes of cfg
We return a list of Base64 strings (one per config).
"""
if len(buf) < 2:
return []
off = 2
total = struct.unpack_from("!H", buf, 0)[0]
end = 2 + total
out = []
while off + 2 <= end:
cfg_len = struct.unpack_from("!H", buf, off)[0]
off += 2
cfg = buf[off : off + cfg_len]
off += cfg_len
out.append(base64.b64encode(cfg).decode())
return out
def parse_https_rdata(rdata: bytes) -> HttpsRecord:
"""
Parse the RDATA of an SVCB/HTTPS record.
Returns a dict with keys: priority, target, alpn, ipv4hint, ipv6hint, echconfig.
"""
off = 0
priority = struct.unpack_from("!H", rdata, off)[0]
off += 2
target, off = read_name(rdata, off)
# pull out all the key/value params
params = {}
while off + 4 <= len(rdata):
key, length = struct.unpack_from("!HH", rdata, off)
off += 4
params[key] = rdata[off : off + length]
off += length
# decode ALPN (key=1), IPv4 (4), IPv6 (6), ECHConfig (5)
def parse_alpn(buf: bytes) -> list[str]:
out = []
i: int = 0
while i < len(buf):
ln = buf[i]
out.append(buf[i + 1 : i + 1 + ln].decode())
i += 1 + ln
return out
alpn: list[str] = parse_alpn(params.get(1, b""))
ipv4 = [
inet4_ntoa(params[4][i : i + 4]) for i in range(0, len(params.get(4, b"")), 4)
]
ipv6 = [
inet6_ntoa(params[6][i : i + 16]) for i in range(0, len(params.get(6, b"")), 16)
]
echconfs = parse_echconfigs(params.get(5, b""))
return {
"priority": priority,
"target": target or ".", # empty name → root
"alpn": alpn,
"ipv4hint": ipv4,
"ipv6hint": ipv6,
"echconfig": echconfs,
}
__all__ = (
"inet4_ntoa",
"inet6_ntoa",
"packet_fragment",
"is_ipv4",
"is_ipv6",
"validate_length_of",
"rfc1035_pack",
"rfc1035_unpack",
"rfc1035_should_read",
"parse_https_rdata",
)