fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .factories import AsyncResolverDescription, AsyncResolverFactory
|
||||
from .protocols import AsyncBaseResolver, AsyncManyResolver
|
||||
|
||||
__all__ = (
|
||||
"AsyncResolverDescription",
|
||||
"AsyncResolverFactory",
|
||||
"AsyncBaseResolver",
|
||||
"AsyncManyResolver",
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._urllib3 import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
HTTPSResolver,
|
||||
NextDNSResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"HTTPSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"OpenDNSResolver",
|
||||
"Quad9Resolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,656 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from asyncio import as_completed
|
||||
from base64 import b64encode
|
||||
|
||||
from ....._async.connectionpool import AsyncHTTPSConnectionPool
|
||||
from ....._async.response import AsyncHTTPResponse
|
||||
from ....._collections import HTTPHeaderDict
|
||||
from .....backend import ConnectionInfo, HttpVersion
|
||||
from .....util.url import parse_url
|
||||
from ...protocols import (
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class HTTPSResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Advanced DNS over HTTPS resolver.
|
||||
No common ground emerged from IETF w/ JSON. Following Google’s DNS over HTTPS schematics that is
|
||||
also implemented at Cloudflare.
|
||||
|
||||
Support RFC 8484 without JSON. Disabled by default.
|
||||
"""
|
||||
|
||||
implementation = "urllib3"
|
||||
protocol = ProtocolResolver.DOH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
super().__init__(server, port or 443, *patterns, **kwargs)
|
||||
|
||||
self._path: str = "/resolve"
|
||||
|
||||
if "path" in kwargs:
|
||||
if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
|
||||
self._path = kwargs["path"]
|
||||
kwargs.pop("path")
|
||||
|
||||
self._rfc8484: bool = False
|
||||
|
||||
if "rfc8484" in kwargs:
|
||||
if kwargs["rfc8484"]:
|
||||
self._rfc8484 = True
|
||||
kwargs.pop("rfc8484")
|
||||
|
||||
assert self._server is not None
|
||||
|
||||
if "source_address" in kwargs:
|
||||
if isinstance(kwargs["source_address"], str):
|
||||
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
|
||||
|
||||
if bind_ip and bind_port.isdigit():
|
||||
kwargs["source_address"] = (
|
||||
bind_ip,
|
||||
int(bind_port),
|
||||
)
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
else:
|
||||
raise ValueError("invalid source_address given in parameters")
|
||||
|
||||
if "proxy" in kwargs:
|
||||
kwargs["_proxy"] = parse_url(kwargs["proxy"])
|
||||
kwargs.pop("proxy")
|
||||
|
||||
if "maxsize" not in kwargs:
|
||||
kwargs["maxsize"] = 10
|
||||
|
||||
if "proxy_headers" in kwargs and "_proxy" in kwargs:
|
||||
proxy_headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["proxy_headers"], list):
|
||||
kwargs["proxy_headers"] = [kwargs["proxy_headers"]]
|
||||
|
||||
for item in kwargs["proxy_headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
proxy_headers.add(k, v)
|
||||
|
||||
kwargs["_proxy_headers"] = proxy_headers
|
||||
|
||||
if "headers" in kwargs:
|
||||
headers = HTTPHeaderDict()
|
||||
|
||||
if not isinstance(kwargs["headers"], list):
|
||||
kwargs["headers"] = [kwargs["headers"]]
|
||||
|
||||
for item in kwargs["headers"]:
|
||||
if ":" not in item:
|
||||
raise ValueError("Passed header is invalid in DNS parameters")
|
||||
|
||||
k, v = item.split(":", 1)
|
||||
headers.add(k, v)
|
||||
|
||||
kwargs["headers"] = headers
|
||||
|
||||
if "disabled_svn" in kwargs:
|
||||
if not isinstance(kwargs["disabled_svn"], list):
|
||||
kwargs["disabled_svn"] = [kwargs["disabled_svn"]]
|
||||
|
||||
disabled_svn = set()
|
||||
|
||||
for svn in kwargs["disabled_svn"]:
|
||||
svn = svn.lower()
|
||||
|
||||
if svn == "h11":
|
||||
disabled_svn.add(HttpVersion.h11)
|
||||
elif svn == "h2":
|
||||
disabled_svn.add(HttpVersion.h2)
|
||||
elif svn == "h3":
|
||||
disabled_svn.add(HttpVersion.h3)
|
||||
|
||||
kwargs["disabled_svn"] = disabled_svn
|
||||
|
||||
if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
|
||||
self._connection_callback: (
|
||||
typing.Callable[[ConnectionInfo], None] | None
|
||||
) = kwargs["on_post_connection"]
|
||||
kwargs.pop("on_post_connection")
|
||||
else:
|
||||
self._connection_callback = None
|
||||
|
||||
self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
await self._pool.close()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._pool.pool is not None
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' from a HTTPSResolver"
|
||||
)
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
promises = []
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "1"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.A, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "28"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.AAAA, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
if not self._rfc8484:
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{"name": host, "type": "65"},
|
||||
headers={"Accept": "application/dns-json"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_query = DomainNameServerQuery(
|
||||
host, SupportedQueryType.HTTPS, override_id=0
|
||||
)
|
||||
dns_payload = bytes(dns_query)
|
||||
|
||||
promises.append(
|
||||
await self._pool.request_encode_url(
|
||||
"GET",
|
||||
self._path,
|
||||
{
|
||||
"dns": b64encode(dns_payload).decode().replace("=", ""),
|
||||
},
|
||||
headers={"Accept": "application/dns-message"},
|
||||
on_post_connection=self._connection_callback,
|
||||
multiplexed=True,
|
||||
)
|
||||
)
|
||||
|
||||
tasks = []
|
||||
responses = []
|
||||
|
||||
for promise in promises:
|
||||
# already resolved
|
||||
if isinstance(promise, AsyncHTTPResponse):
|
||||
responses.append(promise)
|
||||
continue
|
||||
tasks.append(self._pool.get_response(promise=promise))
|
||||
|
||||
if tasks:
|
||||
for waiting_promise_coro in as_completed(tasks):
|
||||
responses.append(await waiting_promise_coro) # type: ignore[arg-type]
|
||||
|
||||
results: list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
] = []
|
||||
|
||||
for response in responses:
|
||||
if response.status >= 300:
|
||||
raise socket.gaierror(
|
||||
f"DNS over HTTPS was unsuccessful, server response status {response.status}."
|
||||
)
|
||||
|
||||
if not self._rfc8484:
|
||||
payload = await response.json()
|
||||
|
||||
assert "Status" in payload and isinstance(payload["Status"], int)
|
||||
|
||||
if payload["Status"] != 0:
|
||||
msg = (
|
||||
payload["Comment"]
|
||||
if "Comment" in payload
|
||||
else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
|
||||
)
|
||||
|
||||
if isinstance(msg, list):
|
||||
msg = ", ".join(msg)
|
||||
|
||||
raise socket.gaierror(msg)
|
||||
|
||||
assert "Question" in payload and isinstance(payload["Question"], list)
|
||||
|
||||
if "Answer" not in payload:
|
||||
continue
|
||||
|
||||
assert isinstance(payload["Answer"], list)
|
||||
|
||||
for answer in payload["Answer"]:
|
||||
if answer["type"] not in [1, 28, 65]:
|
||||
continue
|
||||
|
||||
assert "data" in answer
|
||||
assert isinstance(answer["data"], str)
|
||||
|
||||
# DNS RR/HTTPS
|
||||
if answer["type"] == 65:
|
||||
# "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
|
||||
# or..
|
||||
# "1 . alpn=h2,h3"
|
||||
rr: str = answer["data"]
|
||||
|
||||
if rr.startswith("\\#"): # it means, raw, bytes.
|
||||
rr = "".join(rr[2:].split(" ")[2:])
|
||||
|
||||
try:
|
||||
raw_record = bytes.fromhex(rr)
|
||||
except ValueError:
|
||||
raw_record = b""
|
||||
|
||||
if not raw_record:
|
||||
continue
|
||||
|
||||
https_record = parse_https_rdata(raw_record)
|
||||
|
||||
if "h3" not in https_record["alpn"]:
|
||||
continue
|
||||
|
||||
remote_preemptive_quic_rr = True
|
||||
else:
|
||||
rr_decode: dict[str, str] = dict(
|
||||
tuple(_.lower().split("=", 1)) # type: ignore[misc]
|
||||
for _ in rr.split(" ")
|
||||
if "=" in _
|
||||
)
|
||||
|
||||
if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
|
||||
continue
|
||||
|
||||
remote_preemptive_quic_rr = True
|
||||
|
||||
if "ipv4hint" in rr_decode and family in [
|
||||
socket.AF_UNSPEC,
|
||||
socket.AF_INET,
|
||||
]:
|
||||
for ipv4 in rr_decode["ipv4hint"].split(","):
|
||||
results.append(
|
||||
(
|
||||
socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
ipv4,
|
||||
port,
|
||||
),
|
||||
)
|
||||
)
|
||||
if "ipv6hint" in rr_decode and family in [
|
||||
socket.AF_UNSPEC,
|
||||
socket.AF_INET6,
|
||||
]:
|
||||
for ipv6 in rr_decode["ipv6hint"].split(","):
|
||||
results.append(
|
||||
(
|
||||
socket.AF_INET6,
|
||||
socket.SOCK_DGRAM,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
ipv6,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
|
||||
)
|
||||
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
answer["data"],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
answer["data"],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
else:
|
||||
dns_resp = DomainNameServerReturn(await response.data)
|
||||
|
||||
for record in dns_resp.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
|
||||
dst_addr = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results: list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
] = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
if "rfc8484" in kwargs:
|
||||
if kwargs["rfc8484"]:
|
||||
kwargs["path"] = "/dns-query"
|
||||
super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query"})
|
||||
super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "opendns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
kwargs.update({"path": "/dns-query", "rfc8484": True})
|
||||
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
|
||||
HTTPSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "nextdns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
|
||||
except ImportError:
|
||||
QUICResolver = None # type: ignore
|
||||
AdGuardResolver = None # type: ignore
|
||||
NextDNSResolver = None # type: ignore
|
||||
|
||||
|
||||
__all__ = (
|
||||
"QUICResolver",
|
||||
"AdGuardResolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,557 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from ssl import SSLError
|
||||
from time import time as monotonic
|
||||
|
||||
from qh3.quic.configuration import QuicConfiguration
|
||||
from qh3.quic.connection import QuicConnection
|
||||
from qh3.quic.events import (
|
||||
ConnectionTerminated,
|
||||
HandshakeCompleted,
|
||||
QuicEvent,
|
||||
StopSendingReceived,
|
||||
StreamDataReceived,
|
||||
StreamReset,
|
||||
)
|
||||
|
||||
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
if IS_FIPS:
|
||||
raise ImportError(
|
||||
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
|
||||
)
|
||||
|
||||
|
||||
class QUICResolver(PlainResolver):
|
||||
protocol = ProtocolResolver.DOQ
|
||||
implementation = "qh3"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
):
|
||||
super().__init__(server, port or 853, *patterns, **kwargs)
|
||||
|
||||
# qh3 load_default_certs seems off. need to investigate.
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
kwargs["ca_cert_data"] = []
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
|
||||
try:
|
||||
ctx.load_default_certs()
|
||||
|
||||
for der in ctx.get_ca_certs(binary_form=True):
|
||||
kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))
|
||||
|
||||
if kwargs["ca_cert_data"]:
|
||||
kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
|
||||
else:
|
||||
del kwargs["ca_cert_data"]
|
||||
except (AttributeError, ValueError, OSError):
|
||||
del kwargs["ca_cert_data"]
|
||||
|
||||
if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
|
||||
if (
|
||||
"cert_reqs" not in kwargs
|
||||
or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
|
||||
):
|
||||
raise ssl.SSLError(
|
||||
"DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
|
||||
"Add ?cert_reqs=0 to disable certificate checks."
|
||||
)
|
||||
|
||||
configuration = QuicConfiguration(
|
||||
is_client=True,
|
||||
alpn_protocols=["doq"],
|
||||
server_name=self._server
|
||||
if "server_hostname" not in kwargs
|
||||
else kwargs["server_hostname"],
|
||||
verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
|
||||
if "cert_reqs" in kwargs
|
||||
else ssl.CERT_REQUIRED,
|
||||
cadata=kwargs["ca_cert_data"].encode()
|
||||
if "ca_cert_data" in kwargs
|
||||
else None,
|
||||
cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
|
||||
idle_timeout=300.0,
|
||||
)
|
||||
|
||||
if "cert_file" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_file"],
|
||||
kwargs["key_file"] if "key_file" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
elif "cert_data" in kwargs:
|
||||
configuration.load_cert_chain(
|
||||
kwargs["cert_data"],
|
||||
kwargs["key_data"] if "key_data" in kwargs else None,
|
||||
kwargs["key_password"] if "key_password" in kwargs else None,
|
||||
)
|
||||
|
||||
self._quic = QuicConnection(configuration=configuration)
|
||||
|
||||
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
|
||||
self._connect_attempt: asyncio.Event = asyncio.Event()
|
||||
self._handshake_event: asyncio.Event = asyncio.Event()
|
||||
|
||||
self._terminated: bool = False
|
||||
self._should_disconnect: bool = False
|
||||
|
||||
# DNS over QUIC mandate the size-prefix (unsigned int, 2b)
|
||||
self._rfc1035_prefix_mandated = True
|
||||
|
||||
self._unconsumed: deque[DomainNameServerReturn] = deque()
|
||||
self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
if (
|
||||
not self._terminated
|
||||
and self._socket is not None
|
||||
and not self._socket.should_connect()
|
||||
):
|
||||
self._quic.close()
|
||||
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(monotonic())
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
self._socket.close()
|
||||
await self._socket.wait_for_close()
|
||||
self._terminated = True
|
||||
if self._socket is None or self._socket.should_connect():
|
||||
self._terminated = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._terminated:
|
||||
return False
|
||||
if self._socket is None or self._socket.should_connect():
|
||||
return True
|
||||
self._quic.handle_timer(monotonic())
|
||||
if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
|
||||
self._terminated = True
|
||||
return not self._terminated
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' using the QUICResolver"
|
||||
)
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
if self._socket is None and self._connect_attempt.is_set() is False:
|
||||
self._connect_attempt.set()
|
||||
assert self.server is not None
|
||||
self._quic.connect((self._server, self._port), monotonic())
|
||||
self._socket = await SystemResolver().create_connection(
|
||||
(self.server, self.port or 853),
|
||||
timeout=self._timeout,
|
||||
source_address=self._source_address,
|
||||
socket_options=None,
|
||||
socket_kind=self._socket_type,
|
||||
)
|
||||
await self.__exchange_until(HandshakeCompleted, receive_first=False)
|
||||
self._handshake_event.set()
|
||||
else:
|
||||
await self._handshake_event.wait()
|
||||
|
||||
assert self._socket is not None
|
||||
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
tbq = []
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
tbq.append(SupportedQueryType.A)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
tbq.append(SupportedQueryType.AAAA)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
tbq.append(SupportedQueryType.HTTPS)
|
||||
|
||||
queries = DomainNameServerQuery.bulk(host, *tbq)
|
||||
open_streams = []
|
||||
|
||||
for q in queries:
|
||||
payload = bytes(q)
|
||||
|
||||
self._pending.append(q)
|
||||
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
payload = rfc1035_pack(payload)
|
||||
|
||||
stream_id = self._quic.get_next_available_stream_id()
|
||||
self._quic.send_stream_data(stream_id, payload, True)
|
||||
|
||||
open_streams.append(stream_id)
|
||||
|
||||
for dg in self._quic.datagrams_to_send(monotonic()):
|
||||
await self._socket.sendall(dg[0])
|
||||
|
||||
responses: list[DomainNameServerReturn] = []
|
||||
|
||||
while len(responses) < len(tbq):
|
||||
await self._read_semaphore.acquire()
|
||||
if self._unconsumed:
|
||||
dns_resp = None
|
||||
for query in queries:
|
||||
for unconsumed in self._unconsumed:
|
||||
if unconsumed.id == query.id:
|
||||
dns_resp = unconsumed
|
||||
responses.append(dns_resp)
|
||||
break
|
||||
if dns_resp:
|
||||
break
|
||||
if dns_resp:
|
||||
self._unconsumed.remove(dns_resp)
|
||||
self._pending.remove(query)
|
||||
self._read_semaphore.release()
|
||||
continue
|
||||
|
||||
try:
|
||||
events: list[StreamDataReceived] = await self.__exchange_until( # type: ignore[assignment]
|
||||
StreamDataReceived,
|
||||
receive_first=True,
|
||||
event_type_collectable=(StreamDataReceived,),
|
||||
respect_end_stream_signal=False,
|
||||
)
|
||||
|
||||
payload = b"".join([e.data for e in events])
|
||||
|
||||
while rfc1035_should_read(payload):
|
||||
events.extend(
|
||||
await self.__exchange_until( # type: ignore[arg-type]
|
||||
StreamDataReceived,
|
||||
receive_first=True,
|
||||
event_type_collectable=(StreamDataReceived,),
|
||||
respect_end_stream_signal=False,
|
||||
)
|
||||
)
|
||||
payload = b"".join([e.data for e in events])
|
||||
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
) from e
|
||||
|
||||
self._read_semaphore.release()
|
||||
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
|
||||
fragments = rfc1035_unpack(payload)
|
||||
|
||||
for fragment in fragments:
|
||||
dns_resp = DomainNameServerReturn(fragment)
|
||||
|
||||
if any(dns_resp.id == _.id for _ in queries):
|
||||
responses.append(dns_resp)
|
||||
|
||||
query_tbr: DomainNameServerQuery | None = None
|
||||
|
||||
for query_tbr in self._pending:
|
||||
if query_tbr.id == dns_resp.id:
|
||||
break
|
||||
if query_tbr:
|
||||
self._pending.remove(query_tbr)
|
||||
else:
|
||||
self._unconsumed.append(dns_resp)
|
||||
|
||||
if self._should_disconnect:
|
||||
await self.close()
|
||||
self._should_disconnect = False
|
||||
self._terminated = True
|
||||
|
||||
results = []
|
||||
|
||||
for response in responses:
|
||||
if not response.is_ok:
|
||||
if response.rcode == 2:
|
||||
raise socket.gaierror(
|
||||
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
|
||||
)
|
||||
|
||||
for record in response.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
async def __exchange_until(
|
||||
self,
|
||||
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
|
||||
*,
|
||||
receive_first: bool = False,
|
||||
event_type_collectable: type[QuicEvent]
|
||||
| tuple[type[QuicEvent], ...]
|
||||
| None = None,
|
||||
respect_end_stream_signal: bool = True,
|
||||
) -> list[QuicEvent]:
|
||||
assert self._socket is not None
|
||||
|
||||
while True:
|
||||
if receive_first is False:
|
||||
now = monotonic()
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
events = []
|
||||
|
||||
while True:
|
||||
if not self._quic._events:
|
||||
data_in = await self._socket.recv(1500)
|
||||
|
||||
if not data_in:
|
||||
break
|
||||
|
||||
now = monotonic()
|
||||
|
||||
if not isinstance(data_in, list):
|
||||
self._quic.receive_datagram(
|
||||
data_in, (self._server, self._port), now
|
||||
)
|
||||
else:
|
||||
for gro_segment in data_in:
|
||||
self._quic.receive_datagram(
|
||||
gro_segment, (self._server, self._port), now
|
||||
)
|
||||
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
for ev in iter(self._quic.next_event, None):
|
||||
if isinstance(ev, ConnectionTerminated):
|
||||
if ev.error_code == 298:
|
||||
raise SSLError(
|
||||
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
|
||||
)
|
||||
elif isinstance(ev, StreamReset):
|
||||
self._terminated = True
|
||||
raise socket.gaierror(
|
||||
"DNS over QUIC server submitted a StreamReset. A request was rejected."
|
||||
)
|
||||
elif isinstance(ev, StopSendingReceived):
|
||||
self._should_disconnect = True
|
||||
continue
|
||||
|
||||
if event_type_collectable:
|
||||
if isinstance(ev, event_type_collectable):
|
||||
events.append(ev)
|
||||
else:
|
||||
events.append(ev)
|
||||
|
||||
if isinstance(ev, event_type):
|
||||
if not respect_end_stream_signal:
|
||||
return events
|
||||
if hasattr(ev, "stream_ended") and ev.stream_ended:
|
||||
return events
|
||||
elif hasattr(ev, "stream_ended") is False:
|
||||
return events
|
||||
|
||||
return events
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
QUICResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
|
||||
QUICResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "nextdns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._ssl import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
TLSResolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"TLSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"Quad9Resolver",
|
||||
"OpenDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,197 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util._async.ssl_ import ssl_wrap_socket
|
||||
from .....util.ssl_ import resolve_cert_reqs
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class TLSResolver(PlainResolver):
|
||||
"""
|
||||
Basic DNS resolver over TLS.
|
||||
Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
|
||||
"""
|
||||
|
||||
protocol = ProtocolResolver.DOT
|
||||
implementation = "ssl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
self._socket_type = socket.SOCK_STREAM
|
||||
|
||||
super().__init__(server, port or 853, *patterns, **kwargs)
|
||||
|
||||
# DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
|
||||
self._rfc1035_prefix_mandated = True
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if self._socket is None and self._connect_attempt.is_set() is False:
|
||||
assert self.server is not None
|
||||
self._connect_attempt.set()
|
||||
self._socket = await SystemResolver().create_connection(
|
||||
(self.server, self.port or 853),
|
||||
timeout=self._timeout,
|
||||
source_address=self._source_address,
|
||||
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
|
||||
socket_kind=self._socket_type,
|
||||
)
|
||||
self._socket = await ssl_wrap_socket(
|
||||
self._socket,
|
||||
server_hostname=self.server
|
||||
if "server_hostname" not in self._kwargs
|
||||
else self._kwargs["server_hostname"],
|
||||
keyfile=self._kwargs["key_file"]
|
||||
if "key_file" in self._kwargs
|
||||
else None,
|
||||
certfile=self._kwargs["cert_file"]
|
||||
if "cert_file" in self._kwargs
|
||||
else None,
|
||||
cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
|
||||
if "cert_reqs" in self._kwargs
|
||||
else None,
|
||||
ca_certs=self._kwargs["ca_certs"]
|
||||
if "ca_certs" in self._kwargs
|
||||
else None,
|
||||
ssl_version=self._kwargs["ssl_version"]
|
||||
if "ssl_version" in self._kwargs
|
||||
else None,
|
||||
ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
|
||||
ca_cert_dir=self._kwargs["ca_cert_dir"]
|
||||
if "ca_cert_dir" in self._kwargs
|
||||
else None,
|
||||
key_password=self._kwargs["key_password"]
|
||||
if "key_password" in self._kwargs
|
||||
else None,
|
||||
ca_cert_data=self._kwargs["ca_cert_data"]
|
||||
if "ca_cert_data" in self._kwargs
|
||||
else None,
|
||||
certdata=self._kwargs["cert_data"]
|
||||
if "cert_data" in self._kwargs
|
||||
else None,
|
||||
keydata=self._kwargs["key_data"]
|
||||
if "key_data" in self._kwargs
|
||||
else None,
|
||||
)
|
||||
self._connect_finalized.set()
|
||||
|
||||
return await super().getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
family=family,
|
||||
type=type,
|
||||
proto=proto,
|
||||
flags=flags,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "opendns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
PlainResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"PlainResolver",
|
||||
"CloudflareResolver",
|
||||
"GoogleResolver",
|
||||
"Quad9Resolver",
|
||||
"AdGuardResolver",
|
||||
)
|
||||
@@ -0,0 +1,431 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from ....ssa import AsyncSocket
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
packet_fragment,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..protocols import AsyncBaseResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class PlainResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Minimalist DNS resolver over UDP
|
||||
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
|
||||
|
||||
EDNS is not supported, yet. But we plan to. Willing to contribute?
|
||||
"""
|
||||
|
||||
protocol = ProtocolResolver.DOU
|
||||
implementation = "socket"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
super().__init__(server, port, *patterns, **kwargs)
|
||||
|
||||
self._socket: AsyncSocket | None = None
|
||||
|
||||
if not hasattr(self, "_socket_type"):
|
||||
self._socket_type = socket.SOCK_DGRAM
|
||||
|
||||
if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
|
||||
if ":" in kwargs["source_address"]:
|
||||
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
|
||||
self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
|
||||
else:
|
||||
self._source_address = (kwargs["source_address"], 0)
|
||||
else:
|
||||
self._source_address = None
|
||||
|
||||
if "timeout" in kwargs and isinstance(
|
||||
kwargs["timeout"],
|
||||
(
|
||||
float,
|
||||
int,
|
||||
),
|
||||
):
|
||||
self._timeout: float | int | None = kwargs["timeout"]
|
||||
else:
|
||||
self._timeout = None
|
||||
|
||||
#: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
|
||||
self._rfc1035_prefix_mandated: bool = False
|
||||
|
||||
self._unconsumed: deque[DomainNameServerReturn] = deque()
|
||||
self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
|
||||
self._connect_attempt: asyncio.Event = asyncio.Event()
|
||||
self._connect_finalized: asyncio.Event = asyncio.Event()
|
||||
|
||||
self._terminated: bool = False
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
if not self._terminated:
|
||||
with self._lock:
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
await self._socket.wait_for_close()
|
||||
self._terminated = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return not self._terminated
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' from a PlainResolver"
|
||||
)
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
if self._socket is None and self._connect_attempt.is_set() is False:
|
||||
self._connect_attempt.set()
|
||||
assert self.server is not None
|
||||
self._socket = await SystemResolver().create_connection(
|
||||
(self.server, self.port or 53),
|
||||
timeout=self._timeout,
|
||||
source_address=self._source_address,
|
||||
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
|
||||
socket_kind=self._socket_type,
|
||||
)
|
||||
self._connect_finalized.set()
|
||||
else:
|
||||
await self._connect_finalized.wait()
|
||||
assert self._socket is not None
|
||||
await self._socket.wait_for_readiness()
|
||||
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
tbq = []
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
tbq.append(SupportedQueryType.A)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
tbq.append(SupportedQueryType.AAAA)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
tbq.append(SupportedQueryType.HTTPS)
|
||||
|
||||
queries = DomainNameServerQuery.bulk(host, *tbq)
|
||||
|
||||
for q in queries:
|
||||
payload = bytes(q)
|
||||
self._pending.append(q)
|
||||
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
payload = rfc1035_pack(payload)
|
||||
|
||||
await self._socket.sendall(payload)
|
||||
|
||||
responses: list[DomainNameServerReturn] = []
|
||||
|
||||
while len(responses) < len(tbq):
|
||||
await self._read_semaphore.acquire()
|
||||
#: There we want to verify if another thread got a response that belong to this thread.
|
||||
if self._unconsumed:
|
||||
dns_resp = None
|
||||
|
||||
for query in queries:
|
||||
for unconsumed in self._unconsumed:
|
||||
if unconsumed.id == query.id:
|
||||
dns_resp = unconsumed
|
||||
responses.append(dns_resp)
|
||||
break
|
||||
if dns_resp:
|
||||
break
|
||||
|
||||
if dns_resp:
|
||||
self._pending.remove(query)
|
||||
self._unconsumed.remove(dns_resp)
|
||||
self._read_semaphore.release()
|
||||
continue
|
||||
|
||||
try:
|
||||
data_in_or_segments = await self._socket.recv(1500)
|
||||
|
||||
if isinstance(data_in_or_segments, list):
|
||||
payloads = data_in_or_segments
|
||||
elif data_in_or_segments:
|
||||
payloads = [data_in_or_segments]
|
||||
else:
|
||||
payloads = []
|
||||
|
||||
if self._rfc1035_prefix_mandated is True and payloads:
|
||||
payload = b"".join(payloads)
|
||||
while rfc1035_should_read(payload):
|
||||
extra = await self._socket.recv(1500)
|
||||
if isinstance(extra, list):
|
||||
payload += b"".join(extra)
|
||||
else:
|
||||
payload += extra
|
||||
payloads = [payload]
|
||||
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
) from e
|
||||
|
||||
self._read_semaphore.release()
|
||||
|
||||
if not payloads:
|
||||
self._terminated = True
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
)
|
||||
|
||||
pending_raw_identifiers = [_.raw_id for _ in self._pending]
|
||||
|
||||
for payload in payloads:
|
||||
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
fragments = rfc1035_unpack(payload)
|
||||
else:
|
||||
fragments = packet_fragment(payload, *pending_raw_identifiers)
|
||||
|
||||
for fragment in fragments:
|
||||
dns_resp = DomainNameServerReturn(fragment)
|
||||
|
||||
if any(dns_resp.id == _.id for _ in queries):
|
||||
responses.append(dns_resp)
|
||||
|
||||
query_tbr: DomainNameServerQuery | None = None
|
||||
|
||||
for query_tbr in self._pending:
|
||||
if query_tbr.id == dns_resp.id:
|
||||
break
|
||||
|
||||
if query_tbr:
|
||||
self._pending.remove(query_tbr)
|
||||
else:
|
||||
self._unconsumed.append(dns_resp)
|
||||
|
||||
results = []
|
||||
|
||||
for response in responses:
|
||||
if not response.is_ok:
|
||||
if response.rcode == 2:
|
||||
raise socket.gaierror(
|
||||
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
|
||||
)
|
||||
|
||||
for record in response.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("8.8.8.8", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("9.9.9.9", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("94.140.14.140", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from base64 import b64encode
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from ....util import parse_url
|
||||
from ..factories import ResolverDescription
|
||||
from ..protocols import ProtocolResolver
|
||||
from .protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class AsyncResolverFactory(metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
def has(
|
||||
protocol: ProtocolResolver,
|
||||
specifier: str | None = None,
|
||||
implementation: str | None = None,
|
||||
) -> bool:
|
||||
package_name: str = __name__.split(".")[0]
|
||||
module_expr = f".{protocol.value.replace('-', '_')}"
|
||||
|
||||
if implementation:
|
||||
module_expr += f"._{implementation.replace('-', '_').lower()}"
|
||||
|
||||
try:
|
||||
resolver_module = importlib.import_module(
|
||||
module_expr, f"{package_name}.contrib.resolver._async"
|
||||
)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
|
||||
resolver_module,
|
||||
lambda e: isinstance(e, type)
|
||||
and issubclass(e, AsyncBaseResolver)
|
||||
and (
|
||||
(specifier is None and e.specifier is None) or specifier == e.specifier
|
||||
),
|
||||
)
|
||||
|
||||
if not implementations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def new(
|
||||
protocol: ProtocolResolver,
|
||||
specifier: str | None = None,
|
||||
implementation: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncBaseResolver:
|
||||
package_name: str = __name__.split(".")[0]
|
||||
|
||||
module_expr = f".{protocol.value.replace('-', '_')}"
|
||||
|
||||
if implementation:
|
||||
module_expr += f"._{implementation.replace('-', '_').lower()}"
|
||||
|
||||
spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
|
||||
|
||||
try:
|
||||
resolver_module = importlib.import_module(
|
||||
module_expr, f"{package_name}.contrib.resolver._async"
|
||||
)
|
||||
except ImportError as e:
|
||||
raise NotImplementedError(
|
||||
f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
|
||||
) from e
|
||||
|
||||
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
|
||||
resolver_module,
|
||||
lambda e: isinstance(e, type)
|
||||
and issubclass(e, AsyncBaseResolver)
|
||||
and (
|
||||
(specifier is None and e.specifier is None) or specifier == e.specifier
|
||||
)
|
||||
and hasattr(e, "protocol")
|
||||
and e.protocol == protocol,
|
||||
)
|
||||
|
||||
if not implementations:
|
||||
raise NotImplementedError(
|
||||
f"{protocol}{spe_msg}cannot be loaded. "
|
||||
"No compatible implementation available. "
|
||||
"Make sure your implementation inherit from BaseResolver."
|
||||
)
|
||||
|
||||
implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]
|
||||
|
||||
return implementation_target(**kwargs)
|
||||
|
||||
|
||||
class AsyncResolverDescription(ResolverDescription):
|
||||
"""Describe how a BaseResolver must be instantiated."""
|
||||
|
||||
def new(self) -> AsyncBaseResolver:
|
||||
kwargs = {**self.kwargs}
|
||||
|
||||
if self.server:
|
||||
kwargs["server"] = self.server
|
||||
if self.port:
|
||||
kwargs["port"] = self.port
|
||||
if self.host_patterns:
|
||||
kwargs["patterns"] = self.host_patterns
|
||||
|
||||
return AsyncResolverFactory.new(
|
||||
self.protocol,
|
||||
self.specifier,
|
||||
self.implementation,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_url(url: str) -> AsyncResolverDescription:
|
||||
parsed_url = parse_url(url)
|
||||
|
||||
schema = parsed_url.scheme
|
||||
|
||||
if schema is None:
|
||||
raise ValueError("Given DNS url is missing a protocol")
|
||||
|
||||
specifier = None
|
||||
implementation = None
|
||||
|
||||
if "+" in schema:
|
||||
schema, specifier = tuple(schema.lower().split("+", 1))
|
||||
|
||||
protocol = ProtocolResolver(schema)
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
if parsed_url.path:
|
||||
kwargs["path"] = parsed_url.path
|
||||
|
||||
if parsed_url.auth:
|
||||
kwargs["headers"] = dict()
|
||||
if ":" in parsed_url.auth:
|
||||
username, password = parsed_url.auth.split(":")
|
||||
|
||||
username = username.strip("'\"")
|
||||
password = password.strip("'\"")
|
||||
|
||||
kwargs["headers"]["Authorization"] = (
|
||||
f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
|
||||
)
|
||||
else:
|
||||
kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
|
||||
|
||||
if parsed_url.query:
|
||||
parameters = parse_qs(parsed_url.query)
|
||||
|
||||
for parameter in parameters:
|
||||
if not parameters[parameter]:
|
||||
continue
|
||||
|
||||
parameter_insensible = parameter.lower()
|
||||
|
||||
if (
|
||||
isinstance(parameters[parameter], list)
|
||||
and len(parameters[parameter]) > 1
|
||||
):
|
||||
if parameter == "implementation":
|
||||
raise ValueError("Only one implementation can be passed to URL")
|
||||
|
||||
values = []
|
||||
|
||||
for e in parameters[parameter]:
|
||||
if "," in e:
|
||||
values.extend(e.split(","))
|
||||
else:
|
||||
values.append(e)
|
||||
|
||||
if parameter_insensible in kwargs:
|
||||
if isinstance(kwargs[parameter_insensible], list):
|
||||
kwargs[parameter_insensible].extend(values)
|
||||
else:
|
||||
values.append(kwargs[parameter_insensible])
|
||||
kwargs[parameter_insensible] = values
|
||||
continue
|
||||
|
||||
kwargs[parameter_insensible] = values
|
||||
continue
|
||||
|
||||
value: str = parameters[parameter][0].lower().strip(" ")
|
||||
|
||||
if parameter == "implementation":
|
||||
implementation = value
|
||||
continue
|
||||
|
||||
if "," in value:
|
||||
list_of_values = value.split(",")
|
||||
|
||||
if parameter_insensible in kwargs:
|
||||
if isinstance(kwargs[parameter_insensible], list):
|
||||
kwargs[parameter_insensible].extend(list_of_values)
|
||||
else:
|
||||
list_of_values.append(kwargs[parameter_insensible])
|
||||
continue
|
||||
|
||||
kwargs[parameter_insensible] = list_of_values
|
||||
continue
|
||||
|
||||
value_converted: bool | int | float | None = None
|
||||
|
||||
if value in ["false", "true"]:
|
||||
value_converted = True if value == "true" else False
|
||||
elif value.isdigit():
|
||||
value_converted = int(value)
|
||||
elif (
|
||||
value.count(".") == 1
|
||||
and value.index(".") > 0
|
||||
and value.replace(".", "").isdigit()
|
||||
):
|
||||
value_converted = float(value)
|
||||
|
||||
kwargs[parameter_insensible] = (
|
||||
value if value_converted is None else value_converted
|
||||
)
|
||||
|
||||
host_patterns: list[str] = []
|
||||
|
||||
if "hosts" in kwargs:
|
||||
host_patterns = (
|
||||
kwargs["hosts"].split(",")
|
||||
if isinstance(kwargs["hosts"], str)
|
||||
else kwargs["hosts"]
|
||||
)
|
||||
del kwargs["hosts"]
|
||||
|
||||
return AsyncResolverDescription(
|
||||
protocol,
|
||||
specifier,
|
||||
implementation,
|
||||
parsed_url.host,
|
||||
parsed_url.port,
|
||||
*host_patterns,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._dict import InMemoryResolver
|
||||
|
||||
__all__ = ("InMemoryResolver",)
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util.url import _IPV6_ADDRZ_RE
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class InMemoryResolver(AsyncBaseResolver):
|
||||
protocol = ProtocolResolver.MANUAL
|
||||
implementation = "dict"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
kwargs.pop("port")
|
||||
super().__init__(None, None, *patterns, **kwargs)
|
||||
|
||||
self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
|
||||
self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}
|
||||
|
||||
if self._host_patterns:
|
||||
for record in self._host_patterns:
|
||||
if ":" not in record:
|
||||
continue
|
||||
hostname, addr = record.split(":", 1)
|
||||
self.register(hostname, addr)
|
||||
self._host_patterns = tuple([])
|
||||
|
||||
def recycle(self) -> AsyncBaseResolver:
|
||||
return self
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
pass # no-op
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
def have_constraints(self) -> bool:
|
||||
return True
|
||||
|
||||
def support(self, hostname: str | bytes | None) -> bool | None:
|
||||
if hostname is None:
|
||||
hostname = "localhost"
|
||||
if isinstance(hostname, bytes):
|
||||
hostname = hostname.decode("ascii")
|
||||
return hostname in self._hosts
|
||||
|
||||
def register(self, hostname: str, ipaddr: str) -> None:
|
||||
if hostname not in self._hosts:
|
||||
self._hosts[hostname] = []
|
||||
else:
|
||||
for e in self._hosts[hostname]:
|
||||
t, addr = e
|
||||
if addr in ipaddr:
|
||||
return
|
||||
|
||||
if _IPV6_ADDRZ_RE.match(ipaddr):
|
||||
self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
|
||||
elif is_ipv6(ipaddr):
|
||||
self._hosts[hostname].append((socket.AF_INET6, ipaddr))
|
||||
else:
|
||||
self._hosts[hostname].append((socket.AF_INET, ipaddr))
|
||||
|
||||
if len(self._hosts) > self._maxsize:
|
||||
k = None
|
||||
for k in self._hosts.keys():
|
||||
break
|
||||
if k:
|
||||
self._hosts.pop(k)
|
||||
|
||||
def clear(self, hostname: str) -> None:
|
||||
if hostname in self._hosts:
|
||||
del self._hosts[hostname]
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
host = "localhost" # Defensive: stdlib cpy behavior
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
results: list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
] = []
|
||||
|
||||
if host not in self._hosts:
|
||||
raise socket.gaierror(f"no records found for hostname {host} in-memory")
|
||||
|
||||
for entry in self._hosts[host]:
|
||||
addr_type, addr_target = entry
|
||||
|
||||
if family != socket.AF_UNSPEC:
|
||||
if family != addr_type:
|
||||
continue
|
||||
|
||||
results.append(
|
||||
(
|
||||
addr_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
(addr_target, port)
|
||||
if addr_type == socket.AF_INET
|
||||
else (addr_target, port, 0, 0),
|
||||
)
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise socket.gaierror(f"no records found for hostname {host} in-memory")
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class NullResolver(AsyncBaseResolver):
|
||||
protocol = ProtocolResolver.NULL
|
||||
implementation = "dummy"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
kwargs.pop("port")
|
||||
super().__init__(None, None, *patterns, **kwargs)
|
||||
|
||||
def recycle(self) -> AsyncBaseResolver:
|
||||
return self
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
pass # no-op
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
host = "localhost" # Defensive: stdlib cpy behavior
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
|
||||
|
||||
|
||||
__all__ = ("NullResolver",)
|
||||
@@ -0,0 +1,375 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from ...._constant import UDP_LINUX_GRO
|
||||
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
|
||||
from ....exceptions import LocationParseError
|
||||
from ....util.connection import _set_socket_options, allowed_gai_family
|
||||
from ....util.timeout import _DEFAULT_TIMEOUT
|
||||
from ...ssa import AsyncSocket
|
||||
from ...ssa._timeout import timeout as timeout_
|
||||
from ..protocols import BaseResolver
|
||||
|
||||
|
||||
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
|
||||
def recycle(self) -> AsyncBaseResolver:
|
||||
return super().recycle() # type: ignore[return-value]
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
"""Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
"""This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
|
||||
raise NotImplementedError
|
||||
|
||||
# This function is copied from socket.py in the Python 2.7 standard
|
||||
# library test suite. Added to its signature is only `socket_options`.
|
||||
# One additional modification is that we avoid binding to IPv6 servers
|
||||
# discovered in DNS if the system doesn't have IPv6 functionality.
|
||||
async def create_connection( # type: ignore[override]
|
||||
self,
|
||||
address: tuple[str, int],
|
||||
timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
|
||||
source_address: tuple[str, int] | None = None,
|
||||
socket_options: _TYPE_SOCKET_OPTIONS | None = None,
|
||||
socket_kind: socket.SocketKind = socket.SOCK_STREAM,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
|
||||
| None = None,
|
||||
default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
|
||||
) -> AsyncSocket:
|
||||
"""Connect to *address* and return the socket object.
|
||||
|
||||
Convenience function. Connect to *address* (a 2-tuple ``(host,
|
||||
port)``) and return the socket object. Passing the optional
|
||||
*timeout* parameter will set the timeout on the socket instance
|
||||
before attempting to connect. If no *timeout* is supplied, the
|
||||
global default timeout setting returned by :func:`socket.getdefaulttimeout`
|
||||
is used. If *source_address* is set it must be a tuple of (host, port)
|
||||
for the socket to bind as a source address before making the connection.
|
||||
An host of '' or port 0 tells the OS to use the default.
|
||||
"""
|
||||
|
||||
host, port = address
|
||||
if host.startswith("["):
|
||||
host = host.strip("[]")
|
||||
err = None
|
||||
|
||||
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
|
||||
# us select whether to work with IPv4 DNS records, IPv6 records, or both.
|
||||
# The original create_connection function always returns all records.
|
||||
family = allowed_gai_family()
|
||||
|
||||
if family != socket.AF_UNSPEC:
|
||||
default_socket_family = family
|
||||
|
||||
if source_address is not None:
|
||||
if isinstance(
|
||||
ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
|
||||
):
|
||||
default_socket_family = socket.AF_INET
|
||||
else:
|
||||
default_socket_family = socket.AF_INET6
|
||||
|
||||
try:
|
||||
host.encode("idna")
|
||||
except UnicodeError:
|
||||
raise LocationParseError(f"'{host}', label empty or too long") from None
|
||||
|
||||
dt_pre_resolve = datetime.now(tz=timezone.utc)
|
||||
if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
|
||||
# we can hang here in case of bad networking conditions
|
||||
# the DNS may never answer or the packets can be lost.
|
||||
# this isn't possible in sync mode. unfortunately.
|
||||
# found by user at https://github.com/jawah/niquests/issues/183
|
||||
# todo: find a way to limit getaddrinfo delays in sync mode.
|
||||
try:
|
||||
async with timeout_(timeout):
|
||||
records = await self.getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
default_socket_family,
|
||||
socket_kind,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
except TimeoutError:
|
||||
raise socket.gaierror(
|
||||
f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
|
||||
)
|
||||
else:
|
||||
records = await self.getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
default_socket_family,
|
||||
socket_kind,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve
|
||||
|
||||
dt_pre_established = datetime.now(tz=timezone.utc)
|
||||
for res in records:
|
||||
af, socktype, proto, canonname, sa = res
|
||||
sock = None
|
||||
try:
|
||||
sock = AsyncSocket(af, socktype, proto)
|
||||
|
||||
# we need to add this or reusing the same origin port will likely fail within
|
||||
# short period of time. kernel put port on wait shut.
|
||||
if source_address:
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
except (OSError, AttributeError): # Defensive: very old OS?
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
except (
|
||||
OSError,
|
||||
AttributeError,
|
||||
): # Defensive: we can't do anything better than this.
|
||||
pass
|
||||
|
||||
try:
|
||||
sock.setsockopt(
|
||||
socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
|
||||
)
|
||||
except (OSError, AttributeError):
|
||||
pass
|
||||
|
||||
# attempt to leverage GRO when under Linux
|
||||
if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
|
||||
try:
|
||||
sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
|
||||
except OSError: # Defensive: oh, well(...) anyway!
|
||||
pass
|
||||
|
||||
# If provided, set socket level options before connecting.
|
||||
_set_socket_options(sock, socket_options)
|
||||
|
||||
if timeout is not _DEFAULT_TIMEOUT:
|
||||
sock.settimeout(timeout)
|
||||
if source_address:
|
||||
sock.bind(source_address)
|
||||
|
||||
try:
|
||||
await sock.connect(sa)
|
||||
except asyncio.CancelledError:
|
||||
sock.close()
|
||||
raise
|
||||
|
||||
# Break explicitly a reference cycle
|
||||
err = None
|
||||
|
||||
delta_post_established = (
|
||||
datetime.now(tz=timezone.utc) - dt_pre_established
|
||||
)
|
||||
|
||||
if timing_hook is not None:
|
||||
timing_hook(
|
||||
(
|
||||
delta_post_resolve,
|
||||
delta_post_established,
|
||||
datetime.now(tz=timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
return sock
|
||||
except (OSError, OverflowError) as _:
|
||||
err = _
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
if isinstance(_, OverflowError):
|
||||
break
|
||||
|
||||
if err is not None:
|
||||
try:
|
||||
raise err
|
||||
finally:
|
||||
# Break explicitly a reference cycle
|
||||
err = None
|
||||
else:
|
||||
raise OSError("getaddrinfo returns an empty list")
|
||||
|
||||
|
||||
class AsyncManyResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Special resolver that use many child resolver. Priorities
|
||||
are based on given order (list of BaseResolver).
|
||||
"""
|
||||
|
||||
def __init__(self, *resolvers: AsyncBaseResolver) -> None:
|
||||
super().__init__(None, None)
|
||||
|
||||
self._size = len(resolvers)
|
||||
|
||||
self._unconstrained: list[AsyncBaseResolver] = [
|
||||
_ for _ in resolvers if not _.have_constraints()
|
||||
]
|
||||
self._constrained: list[AsyncBaseResolver] = [
|
||||
_ for _ in resolvers if _.have_constraints()
|
||||
]
|
||||
|
||||
self._concurrent: int = 0
|
||||
self._terminated: bool = False
|
||||
|
||||
def recycle(self) -> AsyncBaseResolver:
|
||||
resolvers = []
|
||||
|
||||
for resolver in self._unconstrained + self._constrained:
|
||||
resolvers.append(resolver.recycle())
|
||||
|
||||
return AsyncManyResolver(*resolvers)
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
for resolver in self._unconstrained + self._constrained:
|
||||
await resolver.close()
|
||||
|
||||
self._terminated = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return not self._terminated
|
||||
|
||||
def __resolvers(
|
||||
self, constrained: bool = False
|
||||
) -> typing.Generator[AsyncBaseResolver, None, None]:
|
||||
resolvers = self._unconstrained if not constrained else self._constrained
|
||||
|
||||
if not resolvers:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._concurrent += 1
|
||||
|
||||
try:
|
||||
resolver_count = len(resolvers)
|
||||
start_idx = (self._concurrent - 1) % resolver_count
|
||||
|
||||
for idx in range(start_idx, resolver_count):
|
||||
if not resolvers[idx].is_available():
|
||||
with self._lock:
|
||||
resolvers[idx] = resolvers[idx].recycle()
|
||||
yield resolvers[idx]
|
||||
|
||||
if start_idx > 0:
|
||||
for idx in range(0, start_idx):
|
||||
if not resolvers[idx].is_available():
|
||||
with self._lock:
|
||||
resolvers[idx] = resolvers[idx].recycle()
|
||||
yield resolvers[idx]
|
||||
finally:
|
||||
with self._lock:
|
||||
self._concurrent -= 1
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii")
|
||||
if host is None:
|
||||
host = "localhost"
|
||||
|
||||
tested_resolvers = []
|
||||
|
||||
any_constrained_tried: bool = False
|
||||
|
||||
for resolver in self.__resolvers(True):
|
||||
can_resolve = resolver.support(host)
|
||||
|
||||
if can_resolve is True:
|
||||
any_constrained_tried = True
|
||||
|
||||
try:
|
||||
results = await resolver.getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
family,
|
||||
type,
|
||||
proto,
|
||||
flags,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
|
||||
if results:
|
||||
return results
|
||||
except socket.gaierror as exc:
|
||||
if isinstance(exc.args[0], str) and (
|
||||
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
|
||||
):
|
||||
raise
|
||||
continue
|
||||
elif can_resolve is False:
|
||||
tested_resolvers.append(resolver)
|
||||
|
||||
if any_constrained_tried:
|
||||
raise socket.gaierror(
|
||||
f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
|
||||
)
|
||||
|
||||
for resolver in self.__resolvers():
|
||||
try:
|
||||
results = await resolver.getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
family,
|
||||
type,
|
||||
proto,
|
||||
flags,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
|
||||
if results:
|
||||
return results
|
||||
except socket.gaierror as exc:
|
||||
if isinstance(exc.args[0], str) and (
|
||||
"DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
|
||||
):
|
||||
raise
|
||||
continue
|
||||
|
||||
raise socket.gaierror(
|
||||
f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import SystemResolver
|
||||
|
||||
__all__ = ("SystemResolver",)
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class SystemResolver(AsyncBaseResolver):
|
||||
implementation = "socket"
|
||||
protocol = ProtocolResolver.SYSTEM
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
kwargs.pop("port")
|
||||
super().__init__(None, None, *patterns, **kwargs)
|
||||
|
||||
def support(self, hostname: str | bytes | None) -> bool | None:
|
||||
if hostname is None:
|
||||
return True
|
||||
if isinstance(hostname, bytes):
|
||||
hostname = hostname.decode("ascii")
|
||||
if hostname == "localhost":
|
||||
return True
|
||||
return super().support(hostname)
|
||||
|
||||
def recycle(self) -> AsyncBaseResolver:
|
||||
return self
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
pass # no-op!
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
return await asyncio.get_running_loop().getaddrinfo(
|
||||
host=host,
|
||||
port=port,
|
||||
family=family,
|
||||
type=type,
|
||||
proto=proto,
|
||||
flags=flags,
|
||||
)
|
||||
Reference in New Issue
Block a user