fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097)과의 포트 충돌로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from .factories import AsyncResolverDescription, AsyncResolverFactory
from .protocols import AsyncBaseResolver, AsyncManyResolver
__all__ = (
"AsyncResolverDescription",
"AsyncResolverFactory",
"AsyncBaseResolver",
"AsyncManyResolver",
)

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from ._urllib3 import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
HTTPSResolver,
NextDNSResolver,
OpenDNSResolver,
Quad9Resolver,
)
__all__ = (
"HTTPSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"OpenDNSResolver",
"Quad9Resolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,656 @@
from __future__ import annotations
import socket
import typing
from asyncio import as_completed
from base64 import b64encode
from ....._async.connectionpool import AsyncHTTPSConnectionPool
from ....._async.response import AsyncHTTPResponse
from ....._collections import HTTPHeaderDict
from .....backend import ConnectionInfo, HttpVersion
from .....util.url import parse_url
from ...protocols import (
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
from ..protocols import AsyncBaseResolver
class HTTPSResolver(AsyncBaseResolver):
"""
Advanced DNS over HTTPS resolver.
No common ground emerged from IETF w/ JSON. Following Googles DNS over HTTPS schematics that is
also implemented at Cloudflare.
Support RFC 8484 without JSON. Disabled by default.
"""
implementation = "urllib3"
protocol = ProtocolResolver.DOH
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
super().__init__(server, port or 443, *patterns, **kwargs)
self._path: str = "/resolve"
if "path" in kwargs:
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
self._path = kwargs["path"]
kwargs.pop("path")
self._rfc8484: bool = False
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
self._rfc8484 = True
kwargs.pop("rfc8484")
assert self._server is not None
if "source_address" in kwargs:
if isinstance(kwargs["source_address"], str):
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
if bind_ip and bind_port.isdigit():
kwargs["source_address"] = (
bind_ip,
int(bind_port),
)
else:
raise ValueError("invalid source_address given in parameters")
else:
raise ValueError("invalid source_address given in parameters")
if "proxy" in kwargs:
kwargs["_proxy"] = parse_url(kwargs["proxy"])
kwargs.pop("proxy")
if "maxsize" not in kwargs:
kwargs["maxsize"] = 10
if "proxy_headers" in kwargs and "_proxy" in kwargs:
proxy_headers = HTTPHeaderDict()
if not isinstance(kwargs["proxy_headers"], list):
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
for item in kwargs["proxy_headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
proxy_headers.add(k, v)
kwargs["_proxy_headers"] = proxy_headers
if "headers" in kwargs:
headers = HTTPHeaderDict()
if not isinstance(kwargs["headers"], list):
kwargs["headers"] = [kwargs["headers"]]
for item in kwargs["headers"]:
if ":" not in item:
raise ValueError("Passed header is invalid in DNS parameters")
k, v = item.split(":", 1)
headers.add(k, v)
kwargs["headers"] = headers
if "disabled_svn" in kwargs:
if not isinstance(kwargs["disabled_svn"], list):
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
disabled_svn = set()
for svn in kwargs["disabled_svn"]:
svn = svn.lower()
if svn == "h11":
disabled_svn.add(HttpVersion.h11)
elif svn == "h2":
disabled_svn.add(HttpVersion.h2)
elif svn == "h3":
disabled_svn.add(HttpVersion.h3)
kwargs["disabled_svn"] = disabled_svn
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
self._connection_callback: (
typing.Callable[[ConnectionInfo], None] | None
) = kwargs["on_post_connection"]
kwargs.pop("on_post_connection")
else:
self._connection_callback = None
self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
async def close(self) -> None: # type: ignore[override]
await self._pool.close()
def is_available(self) -> bool:
return self._pool.pool is not None
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' from a HTTPSResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
promises = []
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
if family in [socket.AF_UNSPEC, socket.AF_INET]:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "1"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.A, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "28"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.AAAA, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
if quic_upgrade_via_dns_rr:
if not self._rfc8484:
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{"name": host, "type": "65"},
headers={"Accept": "application/dns-json"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
else:
dns_query = DomainNameServerQuery(
host, SupportedQueryType.HTTPS, override_id=0
)
dns_payload = bytes(dns_query)
promises.append(
await self._pool.request_encode_url(
"GET",
self._path,
{
"dns": b64encode(dns_payload).decode().replace("=", ""),
},
headers={"Accept": "application/dns-message"},
on_post_connection=self._connection_callback,
multiplexed=True,
)
)
tasks = []
responses = []
for promise in promises:
# already resolved
if isinstance(promise, AsyncHTTPResponse):
responses.append(promise)
continue
tasks.append(self._pool.get_response(promise=promise))
if tasks:
for waiting_promise_coro in as_completed(tasks):
responses.append(await waiting_promise_coro) # type: ignore[arg-type]
results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
for response in responses:
if response.status >= 300:
raise socket.gaierror(
f"DNS over HTTPS was unsuccessful, server response status {response.status}."
)
if not self._rfc8484:
payload = await response.json()
assert "Status" in payload and isinstance(payload["Status"], int)
if payload["Status"] != 0:
msg = (
payload["Comment"]
if "Comment" in payload
else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
)
if isinstance(msg, list):
msg = ", ".join(msg)
raise socket.gaierror(msg)
assert "Question" in payload and isinstance(payload["Question"], list)
if "Answer" not in payload:
continue
assert isinstance(payload["Answer"], list)
for answer in payload["Answer"]:
if answer["type"] not in [1, 28, 65]:
continue
assert "data" in answer
assert isinstance(answer["data"], str)
# DNS RR/HTTPS
if answer["type"] == 65:
# "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
# or..
# "1 . alpn=h2,h3"
rr: str = answer["data"]
if rr.startswith("\\#"): # it means, raw, bytes.
rr = "".join(rr[2:].split(" ")[2:])
try:
raw_record = bytes.fromhex(rr)
except ValueError:
raw_record = b""
if not raw_record:
continue
https_record = parse_https_rdata(raw_record)
if "h3" not in https_record["alpn"]:
continue
remote_preemptive_quic_rr = True
else:
rr_decode: dict[str, str] = dict(
tuple(_.lower().split("=", 1)) # type: ignore[misc]
for _ in rr.split(" ")
if "=" in _
)
if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
continue
remote_preemptive_quic_rr = True
if "ipv4hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET,
]:
for ipv4 in rr_decode["ipv4hint"].split(","):
results.append(
(
socket.AF_INET,
socket.SOCK_DGRAM,
17,
"",
(
ipv4,
port,
),
)
)
if "ipv6hint" in rr_decode and family in [
socket.AF_UNSPEC,
socket.AF_INET6,
]:
for ipv6 in rr_decode["ipv6hint"].split(","):
results.append(
(
socket.AF_INET6,
socket.SOCK_DGRAM,
17,
"",
(
ipv6,
port,
0,
0,
),
)
)
continue
inet_type = (
socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
answer["data"],
port,
)
if inet_type == socket.AF_INET
else (
answer["data"],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
else:
dns_resp = DomainNameServerReturn(await response.data)
for record in dns_resp.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results: list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
] = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class GoogleResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
if "rfc8484" in kwargs:
if kwargs["rfc8484"]:
kwargs["path"] = "/dns-query"
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query"})
super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
class AdGuardResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
kwargs.update({"path": "/dns-query", "rfc8484": True})
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
class NextDNSResolver(
HTTPSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
try:
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
QUICResolver = None # type: ignore
AdGuardResolver = None # type: ignore
NextDNSResolver = None # type: ignore
__all__ = (
"QUICResolver",
"AdGuardResolver",
"NextDNSResolver",
)

View File

@@ -0,0 +1,557 @@
from __future__ import annotations
import asyncio
import socket
import ssl
import typing
from collections import deque
from ssl import SSLError
from time import time as monotonic
from qh3.quic.configuration import QuicConfiguration
from qh3.quic.connection import QuicConnection
from qh3.quic.events import (
ConnectionTerminated,
HandshakeCompleted,
QuicEvent,
StopSendingReceived,
StreamDataReceived,
StreamReset,
)
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..dou import PlainResolver
from ..system import SystemResolver
if IS_FIPS:
raise ImportError(
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
)
class QUICResolver(PlainResolver):
protocol = ProtocolResolver.DOQ
implementation = "qh3"
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
):
super().__init__(server, port or 853, *patterns, **kwargs)
# qh3 load_default_certs seems off. need to investigate.
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
kwargs["ca_cert_data"] = []
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
try:
ctx.load_default_certs()
for der in ctx.get_ca_certs(binary_form=True):
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
if kwargs["ca_cert_data"]:
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
else:
del kwargs["ca_cert_data"]
except (AttributeError, ValueError, OSError):
del kwargs["ca_cert_data"]
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
if (
"cert_reqs" not in kwargs
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
):
raise ssl.SSLError(
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
"Add ?cert_reqs=0 to disable certificate checks."
)
configuration = QuicConfiguration(
is_client=True,
alpn_protocols=["doq"],
server_name=self._server
if "server_hostname" not in kwargs
else kwargs["server_hostname"],
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
if "cert_reqs" in kwargs
else ssl.CERT_REQUIRED,
cadata=kwargs["ca_cert_data"].encode()
if "ca_cert_data" in kwargs
else None,
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
idle_timeout=300.0,
)
if "cert_file" in kwargs:
configuration.load_cert_chain(
kwargs["cert_file"],
kwargs["key_file"] if "key_file" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
elif "cert_data" in kwargs:
configuration.load_cert_chain(
kwargs["cert_data"],
kwargs["key_data"] if "key_data" in kwargs else None,
kwargs["key_password"] if "key_password" in kwargs else None,
)
self._quic = QuicConnection(configuration=configuration)
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
self._connect_attempt: asyncio.Event = asyncio.Event()
self._handshake_event: asyncio.Event = asyncio.Event()
self._terminated: bool = False
self._should_disconnect: bool = False
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
self._rfc1035_prefix_mandated = True
self._unconsumed: deque[DomainNameServerReturn] = deque()
self._pending: deque[DomainNameServerQuery] = deque()
async def close(self) -> None: # type: ignore[override]
if (
not self._terminated
and self._socket is not None
and not self._socket.should_connect()
):
self._quic.close()
while True:
datagrams = self._quic.datagrams_to_send(monotonic())
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
self._socket.close()
await self._socket.wait_for_close()
self._terminated = True
if self._socket is None or self._socket.should_connect():
self._terminated = True
def is_available(self) -> bool:
if self._terminated:
return False
if self._socket is None or self._socket.should_connect():
return True
self._quic.handle_timer(monotonic())
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
self._terminated = True
return not self._terminated
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if host is None:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Tried to resolve 'localhost' using the QUICResolver"
)
if port is None:
port = 0 # Defensive: stdlib cpy behavior
if isinstance(port, str):
port = int(port) # Defensive: stdlib cpy behavior
if port < 0:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Servname not supported for ai_socktype"
)
if isinstance(host, bytes):
host = host.decode("ascii") # Defensive: stdlib cpy behavior
if is_ipv4(host):
if family == socket.AF_INET6:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET,
type,
6,
"",
(
host,
port,
),
)
]
elif is_ipv6(host):
if family == socket.AF_INET:
raise socket.gaierror( # Defensive: stdlib cpy behavior
"Address family for hostname not supported"
)
return [
(
socket.AF_INET6,
type,
17,
"",
(
host,
port,
0,
0,
),
)
]
validate_length_of(host)
if self._socket is None and self._connect_attempt.is_set() is False:
self._connect_attempt.set()
assert self.server is not None
self._quic.connect((self._server, self._port), monotonic())
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 853),
timeout=self._timeout,
source_address=self._source_address,
socket_options=None,
socket_kind=self._socket_type,
)
await self.__exchange_until(HandshakeCompleted, receive_first=False)
self._handshake_event.set()
else:
await self._handshake_event.wait()
assert self._socket is not None
remote_preemptive_quic_rr = False
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
quic_upgrade_via_dns_rr = False
tbq = []
if family in [socket.AF_UNSPEC, socket.AF_INET]:
tbq.append(SupportedQueryType.A)
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
tbq.append(SupportedQueryType.AAAA)
if quic_upgrade_via_dns_rr:
tbq.append(SupportedQueryType.HTTPS)
queries = DomainNameServerQuery.bulk(host, *tbq)
open_streams = []
for q in queries:
payload = bytes(q)
self._pending.append(q)
if self._rfc1035_prefix_mandated is True:
payload = rfc1035_pack(payload)
stream_id = self._quic.get_next_available_stream_id()
self._quic.send_stream_data(stream_id, payload, True)
open_streams.append(stream_id)
for dg in self._quic.datagrams_to_send(monotonic()):
await self._socket.sendall(dg[0])
responses: list[DomainNameServerReturn] = []
while len(responses) < len(tbq):
await self._read_semaphore.acquire()
if self._unconsumed:
dns_resp = None
for query in queries:
for unconsumed in self._unconsumed:
if unconsumed.id == query.id:
dns_resp = unconsumed
responses.append(dns_resp)
break
if dns_resp:
break
if dns_resp:
self._unconsumed.remove(dns_resp)
self._pending.remove(query)
self._read_semaphore.release()
continue
try:
events: list[StreamDataReceived] = await self.__exchange_until( # type: ignore[assignment]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
payload = b"".join([e.data for e in events])
while rfc1035_should_read(payload):
events.extend(
await self.__exchange_until( # type: ignore[arg-type]
StreamDataReceived,
receive_first=True,
event_type_collectable=(StreamDataReceived,),
respect_end_stream_signal=False,
)
)
payload = b"".join([e.data for e in events])
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
raise socket.gaierror(
"Got unexpectedly disconnected while waiting for name resolution"
) from e
self._read_semaphore.release()
if not payload:
continue
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
fragments = rfc1035_unpack(payload)
for fragment in fragments:
dns_resp = DomainNameServerReturn(fragment)
if any(dns_resp.id == _.id for _ in queries):
responses.append(dns_resp)
query_tbr: DomainNameServerQuery | None = None
for query_tbr in self._pending:
if query_tbr.id == dns_resp.id:
break
if query_tbr:
self._pending.remove(query_tbr)
else:
self._unconsumed.append(dns_resp)
if self._should_disconnect:
await self.close()
self._should_disconnect = False
self._terminated = True
results = []
for response in responses:
if not response.is_ok:
if response.rcode == 2:
raise socket.gaierror(
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
)
raise socket.gaierror(
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
)
for record in response.records:
if record[0] == SupportedQueryType.HTTPS:
assert isinstance(record[-1], dict)
if "h3" in record[-1]["alpn"]:
remote_preemptive_quic_rr = True
continue
assert not isinstance(record[-1], dict)
inet_type = (
socket.AF_INET
if record[0] == SupportedQueryType.A
else socket.AF_INET6
)
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
(
record[-1],
port,
)
if inet_type == socket.AF_INET
else (
record[-1],
port,
0,
0,
)
)
results.append(
(
inet_type,
type,
6 if type == socket.SOCK_STREAM else 17,
"",
dst_addr,
)
)
quic_results = []
if remote_preemptive_quic_rr:
any_specified = False
for result in results:
if result[1] == socket.SOCK_STREAM:
quic_results.append(
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
)
else:
any_specified = True
break
if any_specified:
quic_results = []
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
async def __exchange_until(
self,
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
*,
receive_first: bool = False,
event_type_collectable: type[QuicEvent]
| tuple[type[QuicEvent], ...]
| None = None,
respect_end_stream_signal: bool = True,
) -> list[QuicEvent]:
assert self._socket is not None
while True:
if receive_first is False:
now = monotonic()
while True:
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
events = []
while True:
if not self._quic._events:
data_in = await self._socket.recv(1500)
if not data_in:
break
now = monotonic()
if not isinstance(data_in, list):
self._quic.receive_datagram(
data_in, (self._server, self._port), now
)
else:
for gro_segment in data_in:
self._quic.receive_datagram(
gro_segment, (self._server, self._port), now
)
while True:
datagrams = self._quic.datagrams_to_send(now)
if not datagrams:
break
for datagram in datagrams:
data, addr = datagram
await self._socket.sendall(data)
for ev in iter(self._quic.next_event, None):
if isinstance(ev, ConnectionTerminated):
if ev.error_code == 298:
raise SSLError(
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
)
raise socket.gaierror(
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
)
elif isinstance(ev, StreamReset):
self._terminated = True
raise socket.gaierror(
"DNS over QUIC server submitted a StreamReset. A request was rejected."
)
elif isinstance(ev, StopSendingReceived):
self._should_disconnect = True
continue
if event_type_collectable:
if isinstance(ev, event_type_collectable):
events.append(ev)
else:
events.append(ev)
if isinstance(ev, event_type):
if not respect_end_stream_signal:
return events
if hasattr(ev, "stream_ended") and ev.stream_ended:
return events
elif hasattr(ev, "stream_ended") is False:
return events
return events
class AdGuardResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class NextDNSResolver(
QUICResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "nextdns"
def __init__(self, *patterns: str, **kwargs: typing.Any):
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from ._ssl import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
OpenDNSResolver,
Quad9Resolver,
TLSResolver,
)
__all__ = (
"TLSResolver",
"GoogleResolver",
"CloudflareResolver",
"AdGuardResolver",
"Quad9Resolver",
"OpenDNSResolver",
)

View File

@@ -0,0 +1,197 @@
from __future__ import annotations
import socket
import typing
from .....util._async.ssl_ import ssl_wrap_socket
from .....util.ssl_ import resolve_cert_reqs
from ...protocols import ProtocolResolver
from ..dou import PlainResolver
from ..system import SystemResolver
class TLSResolver(PlainResolver):
"""
Basic DNS resolver over TLS.
Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
"""
protocol = ProtocolResolver.DOT
implementation = "ssl"
def __init__(
self,
server: str | None,
port: int | None = None,
*patterns: str,
**kwargs: typing.Any,
) -> None:
self._socket_type = socket.SOCK_STREAM
super().__init__(server, port or 853, *patterns, **kwargs)
# DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
self._rfc1035_prefix_mandated = True
async def getaddrinfo( # type: ignore[override]
self,
host: bytes | str | None,
port: str | int | None,
family: socket.AddressFamily,
type: socket.SocketKind,
proto: int = 0,
flags: int = 0,
*,
quic_upgrade_via_dns_rr: bool = False,
) -> list[
tuple[
socket.AddressFamily,
socket.SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
if self._socket is None and self._connect_attempt.is_set() is False:
assert self.server is not None
self._connect_attempt.set()
self._socket = await SystemResolver().create_connection(
(self.server, self.port or 853),
timeout=self._timeout,
source_address=self._source_address,
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
socket_kind=self._socket_type,
)
self._socket = await ssl_wrap_socket(
self._socket,
server_hostname=self.server
if "server_hostname" not in self._kwargs
else self._kwargs["server_hostname"],
keyfile=self._kwargs["key_file"]
if "key_file" in self._kwargs
else None,
certfile=self._kwargs["cert_file"]
if "cert_file" in self._kwargs
else None,
cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
if "cert_reqs" in self._kwargs
else None,
ca_certs=self._kwargs["ca_certs"]
if "ca_certs" in self._kwargs
else None,
ssl_version=self._kwargs["ssl_version"]
if "ssl_version" in self._kwargs
else None,
ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
ca_cert_dir=self._kwargs["ca_cert_dir"]
if "ca_cert_dir" in self._kwargs
else None,
key_password=self._kwargs["key_password"]
if "key_password" in self._kwargs
else None,
ca_cert_data=self._kwargs["ca_cert_data"]
if "ca_cert_data" in self._kwargs
else None,
certdata=self._kwargs["cert_data"]
if "cert_data" in self._kwargs
else None,
keydata=self._kwargs["key_data"]
if "key_data" in self._kwargs
else None,
)
self._connect_finalized.set()
return await super().getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
)
class GoogleResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "google"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.google", port, *patterns, **kwargs)
class CloudflareResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "cloudflare"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("1.1.1.1", port, *patterns, **kwargs)
class AdGuardResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "adguard"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
class OpenDNSResolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "opendns"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
class Quad9Resolver(
TLSResolver
): # Defensive: we do not cover specific vendors/DNS shortcut
specifier = "quad9"
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
if "server" in kwargs:
kwargs.pop("server")
if "port" in kwargs:
port = kwargs["port"]
kwargs.pop("port")
else:
port = None
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from ._socket import (
AdGuardResolver,
CloudflareResolver,
GoogleResolver,
PlainResolver,
Quad9Resolver,
)
__all__ = (
"PlainResolver",
"CloudflareResolver",
"GoogleResolver",
"Quad9Resolver",
"AdGuardResolver",
)

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
import asyncio
import socket
import typing
from collections import deque
from ....ssa import AsyncSocket
from ...protocols import (
COMMON_RCODE_LABEL,
DomainNameServerQuery,
DomainNameServerReturn,
ProtocolResolver,
SupportedQueryType,
)
from ...utils import (
is_ipv4,
is_ipv6,
packet_fragment,
rfc1035_pack,
rfc1035_should_read,
rfc1035_unpack,
validate_length_of,
)
from ..protocols import AsyncBaseResolver
from ..system import SystemResolver
class PlainResolver(AsyncBaseResolver):
    """
    Minimalist DNS resolver over UDP
    Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
    EDNS is not supported, yet. But we plan to. Willing to contribute?
    """

    protocol = ProtocolResolver.DOU
    implementation = "socket"

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Prepare a plain DNS resolver pointed at *server* (default port 53).

        Recognized extra keyword options: ``source_address`` ("ip" or "ip:port",
        used to bind the local socket end) and ``timeout`` (seconds).
        """
        super().__init__(server, port, *patterns, **kwargs)

        # The transport socket is created lazily on the first getaddrinfo() call.
        self._socket: AsyncSocket | None = None

        # A subclass (e.g. DNS over TCP/TLS) may have set a different socket
        # kind before delegating to us; default to UDP datagrams otherwise.
        if not hasattr(self, "_socket_type"):
            self._socket_type = socket.SOCK_DGRAM

        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            if ":" in kwargs["source_address"]:
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
                self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
            else:
                # No port given: let the OS pick an ephemeral one.
                self._source_address = (kwargs["source_address"], 0)
        else:
            self._source_address = None

        if "timeout" in kwargs and isinstance(
            kwargs["timeout"],
            (
                float,
                int,
            ),
        ):
            self._timeout: float | int | None = kwargs["timeout"]
        else:
            self._timeout = None

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # Responses read from the socket that belong to another concurrent
        # getaddrinfo() caller; they are parked here until claimed.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        # Queries already sent and still awaiting their answer.
        self._pending: deque[DomainNameServerQuery] = deque()

        # Serializes reads on the shared socket across concurrent callers.
        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        # First caller to arrive wins the right to establish the connection...
        self._connect_attempt: asyncio.Event = asyncio.Event()
        # ...and signals everyone else once the socket is ready for use.
        self._connect_finalized: asyncio.Event = asyncio.Event()

        self._terminated: bool = False

    async def close(self) -> None:  # type: ignore[override]
        """Close the underlying socket (if any) and mark this resolver unusable."""
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    self._socket.close()
                    await self._socket.wait_for_close()
                self._terminated = True

    def is_available(self) -> bool:
        # Usable until close() has run (or a fatal disconnect flipped the flag).
        return not self._terminated

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* by sending A/AAAA (and optionally HTTPS) queries.

        Mirrors the stdlib ``socket.getaddrinfo`` contract: literal IPv4/IPv6
        inputs short-circuit without any network traffic, and failures raise
        ``socket.gaierror``. When ``quic_upgrade_via_dns_rr`` is set and the
        HTTPS record advertises h3, equivalent SOCK_DGRAM entries are
        prepended so callers may attempt a QUIC upgrade.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal IP addresses bypass the wire protocol entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        # RFC 1035 caps the length of a domain name; reject early.
        validate_length_of(host)

        # One shared socket per resolver: the first caller connects, the rest
        # wait for the _connect_finalized signal.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()
            assert self.server is not None
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 53),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
                socket_kind=self._socket_type,
            )
            self._connect_finalized.set()
        else:
            await self._connect_finalized.wait()

        assert self._socket is not None
        await self._socket.wait_for_readiness()

        remote_preemptive_quic_rr = False

        # A QUIC upgrade hint is pointless when the caller already asked for
        # datagram sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Build the list of record types to query for.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)
        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        for q in queries:
            payload = bytes(q)
            self._pending.append(q)
            if self._rfc1035_prefix_mandated is True:
                # Stream transports (TCP/TLS) require the two-byte length prefix.
                payload = rfc1035_pack(payload)
            await self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        # Keep reading until every query we sent has been answered.
        while len(responses) < len(tbq):
            await self._read_semaphore.acquire()

            #: There we want to verify if another thread got a response that belong to this thread.
            if self._unconsumed:
                dns_resp = None

                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break

                if dns_resp:
                    self._pending.remove(query)
                    self._unconsumed.remove(dns_resp)
                    self._read_semaphore.release()
                    continue

            try:
                data_in_or_segments = await self._socket.recv(1500)

                # recv may hand back a list of datagrams or a single buffer.
                if isinstance(data_in_or_segments, list):
                    payloads = data_in_or_segments
                elif data_in_or_segments:
                    payloads = [data_in_or_segments]
                else:
                    payloads = []

                if self._rfc1035_prefix_mandated is True and payloads:
                    # Stream transport: keep reading until the length prefix
                    # says the message is complete.
                    payload = b"".join(payloads)
                    while rfc1035_should_read(payload):
                        extra = await self._socket.recv(1500)
                        if isinstance(extra, list):
                            payload += b"".join(extra)
                        else:
                            payload += extra
                    payloads = [payload]
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            if not payloads:
                # Peer closed on us: the shared socket is dead for everyone.
                self._terminated = True
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                )

            pending_raw_identifiers = [_.raw_id for _ in self._pending]

            for payload in payloads:
                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                if self._rfc1035_prefix_mandated is True:
                    fragments = rfc1035_unpack(payload)
                else:
                    fragments = packet_fragment(payload, *pending_raw_identifiers)

                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)

                    if any(dns_resp.id == _.id for _ in queries):
                        # This answer belongs to one of *our* queries.
                        responses.append(dns_resp)

                        # NOTE(review): if no pending entry matches, this
                        # for/break pattern leaves query_tbr bound to the last
                        # pending item and removes it below — presumably the
                        # match is always present; verify under concurrency.
                        query_tbr: DomainNameServerQuery | None = None
                        for query_tbr in self._pending:
                            if query_tbr.id == dns_resp.id:
                                break
                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        # Belongs to a concurrent caller: park it for them.
                        self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                # rcode 2 == SERVFAIL; surface DNSSEC hints for that case.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                if record[0] == SupportedQueryType.HTTPS:
                    # HTTPS RR payload is a dict; check for an h3 ALPN hint.
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        if remote_preemptive_quic_rr:
            # Mirror every TCP candidate as a UDP/QUIC candidate, but only if
            # the caller did not pin a non-stream socket kind.
            any_specified = False
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break
            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
class CloudflareResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS (UDP) preset targeting Cloudflare DNS (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint is fixed for this vendor: discard any caller-supplied
        # server, but honor an explicit port override when present.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
class GoogleResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS (UDP) preset targeting Google Public DNS (8.8.8.8)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint is fixed for this vendor: discard any caller-supplied
        # server, but honor an explicit port override when present.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("8.8.8.8", port, *patterns, **kwargs)
class Quad9Resolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS (UDP) preset targeting Quad9 (9.9.9.9)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint is fixed for this vendor: discard any caller-supplied
        # server, but honor an explicit port override when present.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("9.9.9.9", port, *patterns, **kwargs)
class AdGuardResolver(
    PlainResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """Plain DNS (UDP) preset targeting AdGuard's unfiltered endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint is fixed for this vendor: discard any caller-supplied
        # server, but honor an explicit port override when present.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("94.140.14.140", port, *patterns, **kwargs)

View File

@@ -0,0 +1,243 @@
from __future__ import annotations
import importlib
import inspect
import typing
from abc import ABCMeta
from base64 import b64encode
from typing import Any
from urllib.parse import parse_qs
from ....util import parse_url
from ..factories import ResolverDescription
from ..protocols import ProtocolResolver
from .protocols import AsyncBaseResolver
class AsyncResolverFactory(metaclass=ABCMeta):
    """Discover and instantiate concrete :class:`AsyncBaseResolver` classes at runtime."""

    @staticmethod
    def has(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
    ) -> bool:
        """Tell whether at least one resolver matching the criteria can be imported."""
        root_package: str = __name__.split(".")[0]

        target_module = f".{protocol.value.replace('-', '_')}"
        if implementation:
            target_module += f"._{implementation.replace('-', '_').lower()}"

        try:
            module = importlib.import_module(
                target_module, f"{root_package}.contrib.resolver._async"
            )
        except ImportError:
            return False

        candidates: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            module,
            lambda member: isinstance(member, type)
            and issubclass(member, AsyncBaseResolver)
            and (
                (specifier is None and member.specifier is None)
                or specifier == member.specifier
            ),
        )

        return bool(candidates)

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> AsyncBaseResolver:
        """Instantiate a resolver matching the criteria, forwarding **kwargs.

        :raises NotImplementedError: when no module or no compatible class exists.
        """
        root_package: str = __name__.split(".")[0]

        target_module = f".{protocol.value.replace('-', '_')}"
        if implementation:
            target_module += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            module = importlib.import_module(
                target_module, f"{root_package}.contrib.resolver._async"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{target_module}'. Did you specify a non-existent implementation?"
            ) from e

        # getmembers() yields (name, class) pairs sorted by name; we keep that
        # ordering and pick the last match, like the sync factory does.
        candidates: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
            module,
            lambda member: isinstance(member, type)
            and issubclass(member, AsyncBaseResolver)
            and (
                (specifier is None and member.specifier is None)
                or specifier == member.specifier
            )
            and hasattr(member, "protocol")
            and member.protocol == protocol,
        )

        if not candidates:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        chosen: type[AsyncBaseResolver] = candidates.pop()[1]

        return chosen(**kwargs)
class AsyncResolverDescription(ResolverDescription):
    """Describe how a BaseResolver must be instantiated."""

    def new(self) -> AsyncBaseResolver:
        """Build the described resolver through :class:`AsyncResolverFactory`."""
        kwargs = {**self.kwargs}

        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns

        return AsyncResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> AsyncResolverDescription:
        """Parse a DNS url (e.g. ``dou://1.1.1.1`` or ``doh+google://``) into a description.

        Scheme may carry a ``+specifier`` suffix; userinfo becomes an
        ``Authorization`` header; query parameters become constructor kwargs
        (comma and repeated values merge into lists, booleans/numbers are
        converted); the ``hosts`` parameter becomes host patterns.

        :raises ValueError: when the url has no scheme, or more than one
            ``implementation`` is given.
        """
        parsed_url = parse_url(url)
        schema = parsed_url.scheme

        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")

        specifier = None
        implementation = None

        if "+" in schema:
            schema, specifier = tuple(schema.lower().split("+", 1))

        protocol = ProtocolResolver(schema)

        kwargs: dict[str, typing.Any] = {}

        if parsed_url.path:
            kwargs["path"] = parsed_url.path
        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                username, password = parsed_url.auth.split(":")
                username = username.strip("'\"")
                password = password.strip("'\"")
                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                # A lone userinfo token is treated as a bearer token.
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"

        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)

            for parameter in parameters:
                if not parameters[parameter]:
                    continue

                parameter_insensible = parameter.lower()

                # Repeated occurrences of the same query parameter.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")

                    values = []
                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue

                    kwargs[parameter_insensible] = values
                    continue

                value: str = parameters[parameter][0].lower().strip(" ")

                if parameter == "implementation":
                    implementation = value
                    continue

                # A single occurrence may still carry a comma-separated list.
                if "," in value:
                    list_of_values = value.split(",")

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            list_of_values.append(kwargs[parameter_insensible])
                            # Bugfix: the merged list was previously never stored
                            # back, silently dropping the new values (the
                            # multi-occurrence branch above does assign).
                            kwargs[parameter_insensible] = list_of_values
                        continue

                    kwargs[parameter_insensible] = list_of_values
                    continue

                # Scalar value: convert "true"/"false", ints and simple floats.
                value_converted: bool | int | float | None = None

                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)

                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )

        host_patterns: list[str] = []

        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]

        return AsyncResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._dict import InMemoryResolver
__all__ = ("InMemoryResolver",)

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
import socket
import typing
from .....util.url import _IPV6_ADDRZ_RE
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class InMemoryResolver(AsyncBaseResolver):
    """Resolver backed by a manually populated, bounded in-memory host table."""

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        """Seed the table from ``"hostname:addr"`` patterns.

        Recognized extra keyword option: ``maxsize`` — maximum number of
        hostnames retained (default 65535).
        """
        # This resolver has no remote endpoint: drop server/port if given.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")

        super().__init__(None, None, *patterns, **kwargs)

        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        # hostname -> list of (address family, address) entries.
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}

        if self._host_patterns:
            for record in self._host_patterns:
                # Patterns without a ':' carry no address; ignore them.
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            # Patterns were consumed into the table; none remain as constraints.
            self._host_patterns = tuple([])

    def recycle(self) -> AsyncBaseResolver:
        # Stateless lifecycle: the same instance is always reusable.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Only registered hostnames are answerable, so always constrained.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True only for hostnames present in the in-memory table."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add *ipaddr* for *hostname*, evicting the oldest hostname when full."""
        if hostname not in self._hosts:
            self._hosts[hostname] = []
        else:
            for e in self._hosts[hostname]:
                t, addr = e
                # Substring test so "::1" matches a bracketed "[::1]" input;
                # NOTE(review): also matches unrelated substrings — presumably
                # acceptable for this use; confirm.
                if addr in ipaddr:
                    return

        if _IPV6_ADDRZ_RE.match(ipaddr):
            # Bracketed IPv6 literal: strip the surrounding "[" and "]".
            self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
        elif is_ipv6(ipaddr):
            self._hosts[hostname].append((socket.AF_INET6, ipaddr))
        else:
            self._hosts[hostname].append((socket.AF_INET, ipaddr))

        if len(self._hosts) > self._maxsize:
            # Evict the oldest hostname (dicts preserve insertion order).
            k = None
            for k in self._hosts.keys():
                break
            if k:
                self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every address registered for *hostname*."""
        if hostname in self._hosts:
            del self._hosts[hostname]

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Answer from the in-memory table only; raises gaierror on a miss."""
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal IP addresses short-circuit without touching the table.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        if host not in self._hosts:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        for entry in self._hosts[host]:
            addr_type, addr_target = entry
            # Honor the requested address family when it is pinned.
            if family != socket.AF_UNSPEC:
                if family != addr_type:
                    continue
            results.append(
                (
                    addr_type,
                    type,
                    6 if type == socket.SOCK_STREAM else 17,
                    "",
                    (addr_target, port)
                    if addr_type == socket.AF_INET
                    else (addr_target, port, 0, 0),
                )
            )

        # All entries may have been filtered out by the family constraint.
        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        return results

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import socket
import typing
from ...protocols import ProtocolResolver
from ...utils import is_ipv4, is_ipv6
from ..protocols import AsyncBaseResolver
class NullResolver(AsyncBaseResolver):
    """Resolver that accepts only literal IP addresses and refuses everything else."""

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no remote endpoint: drop server/port if given.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> AsyncBaseResolver:
        # Stateless: the same instance may be reused forever.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Pass literal IPv4/IPv6 hosts through; raise gaierror for any real name."""
        # Normalize inputs the way stdlib getaddrinfo would.
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [(socket.AF_INET, type, 6, "", (host, port))]

        if is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]

        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
__all__ = ("NullResolver",)

View File

@@ -0,0 +1,375 @@
from __future__ import annotations
import asyncio
import ipaddress
import socket
import struct
import sys
import typing
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta, timezone
from ...._constant import UDP_LINUX_GRO
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
from ....exceptions import LocationParseError
from ....util.connection import _set_socket_options, allowed_gai_family
from ....util.timeout import _DEFAULT_TIMEOUT
from ...ssa import AsyncSocket
from ...ssa._timeout import timeout as timeout_
from ..protocols import BaseResolver
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
    """Async counterpart of :class:`BaseResolver`: resolve names and open sockets."""

    def recycle(self) -> AsyncBaseResolver:
        # Delegate to the sync base; cast back to the async type for callers.
        return super().recycle()  # type: ignore[return-value]

    @abstractmethod
    async def close(self) -> None:  # type: ignore[override]
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    async def create_connection(  # type: ignore[override]
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> AsyncSocket:
        """Connect to *address* and return the socket object.

        Convenience function.  Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object.  Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect.  If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used.  If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.
        """
        host, port = address

        # Accept bracketed IPv6 literals ("[::1]").
        if host.startswith("["):
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        # Binding to a specific local address pins the family accordingly.
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        dt_pre_resolve = datetime.now(tz=timezone.utc)

        if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
            # we can hang here in case of bad networking conditions
            # the DNS may never answer or the packets can be lost.
            # this isn't possible in sync mode. unfortunately.
            # found by user at https://github.com/jawah/niquests/issues/183
            # todo: find a way to limit getaddrinfo delays in sync mode.
            try:
                async with timeout_(timeout):
                    records = await self.getaddrinfo(
                        host,
                        port,
                        default_socket_family,
                        socket_kind,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
            except TimeoutError:
                raise socket.gaierror(
                    f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
                )
        else:
            records = await self.getaddrinfo(
                host,
                port,
                default_socket_family,
                socket_kind,
                quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
            )

        # Timing split: resolution delay vs. connection-establishment delay.
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)

        # Try every resolved candidate in order; first successful connect wins.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = AsyncSocket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (OSError, AttributeError):  # Defensive: very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)

                try:
                    await sock.connect(sa)
                except asyncio.CancelledError:
                    # Don't leak the half-open socket on task cancellation.
                    sock.close()
                    raise

                # Break explicitly a reference cycle
                err = None

                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )

                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )

                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. invalid port) won't improve with the
                # next candidate: stop trying.
                if isinstance(_, OverflowError):
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
class AsyncManyResolver(AsyncBaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: AsyncBaseResolver) -> None:
        super().__init__(None, None)

        self._size = len(resolvers)

        # Constrained resolvers only answer for specific host patterns and are
        # always consulted before the unconstrained (catch-all) ones.
        self._unconstrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]

        # Number of in-flight getaddrinfo() callers; used to spread load.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> AsyncBaseResolver:
        # Build a fresh aggregate from recycled children.
        resolvers = []

        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())

        return AsyncManyResolver(*resolvers)

    async def close(self) -> None:  # type: ignore[override]
        for resolver in self._unconstrained + self._constrained:
            await resolver.close()
        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[AsyncBaseResolver, None, None]:
        """Yield child resolvers, rotated by concurrency so simultaneous
        callers start at different children; unavailable ones are recycled
        in place before being yielded."""
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            # Round-robin start offset derived from the caller count.
            start_idx = (self._concurrent - 1) % resolver_count

            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            # Wrap around to cover the children before the start offset.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Try constrained resolvers that claim *host* first, then the
        unconstrained pool; DNSSEC-related failures propagate immediately,
        other gaierrors fall through to the next child."""
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        tested_resolvers = []

        any_constrained_tried: bool = False

        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True

                try:
                    results = await resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )

                    if results:
                        return results
                except socket.gaierror as exc:
                    # DNSSEC problems are security-relevant: never mask them
                    # by falling back to another resolver.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)

        # A constrained resolver claimed the host but could not answer: do not
        # leak the query to the general-purpose resolvers.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        for resolver in self.__resolvers():
            try:
                results = await resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )

                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )

View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from ._socket import SystemResolver
__all__ = ("SystemResolver",)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import asyncio
import socket
import typing
from ...protocols import ProtocolResolver
from ..protocols import AsyncBaseResolver
class SystemResolver(AsyncBaseResolver):
    """Resolver that defers to the operating system via the running event loop."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no remote endpoint of its own: drop server/port.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        # The system resolver can always answer for the local machine.
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> AsyncBaseResolver:
        # Stateless: reuse the very same instance.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op!

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Delegate to ``loop.getaddrinfo``; ``quic_upgrade_via_dns_rr`` is ignored here."""
        loop = asyncio.get_running_loop()
        return await loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )