fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .factories import ResolverDescription, ResolverFactory
|
||||
from .protocols import BaseResolver, ManyResolver, ProtocolResolver
|
||||
|
||||
# Public API of the synchronous resolver package, re-exported from the
# sibling modules imported above.
__all__ = (
    "ResolverFactory",
    "ProtocolResolver",
    "BaseResolver",
    "ManyResolver",
    "ResolverDescription",
)
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .factories import AsyncResolverDescription, AsyncResolverFactory
|
||||
from .protocols import AsyncBaseResolver, AsyncManyResolver
|
||||
|
||||
# Public API of the asynchronous resolver package, re-exported from the
# sibling modules imported above.
__all__ = (
    "AsyncResolverDescription",
    "AsyncResolverFactory",
    "AsyncBaseResolver",
    "AsyncManyResolver",
)
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._urllib3 import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
HTTPSResolver,
|
||||
NextDNSResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
# Public API of the DNS-over-HTTPS resolver package: the generic resolver
# plus the vendor-preconfigured shortcuts re-exported from ._urllib3.
__all__ = (
    "HTTPSResolver",
    "GoogleResolver",
    "CloudflareResolver",
    "AdGuardResolver",
    "OpenDNSResolver",
    "Quad9Resolver",
    "NextDNSResolver",
)
|
||||
@@ -0,0 +1,656 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from asyncio import as_completed
|
||||
from base64 import b64encode
|
||||
|
||||
from ....._async.connectionpool import AsyncHTTPSConnectionPool
|
||||
from ....._async.response import AsyncHTTPResponse
|
||||
from ....._collections import HTTPHeaderDict
|
||||
from .....backend import ConnectionInfo, HttpVersion
|
||||
from .....util.url import parse_url
|
||||
from ...protocols import (
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class HTTPSResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Advanced DNS over HTTPS resolver.
|
||||
No common ground emerged from IETF w/ JSON. Following Google’s DNS over HTTPS schematics that is
|
||||
also implemented at Cloudflare.
|
||||
|
||||
Support RFC 8484 without JSON. Disabled by default.
|
||||
"""
|
||||
|
||||
implementation = "urllib3"
|
||||
protocol = ProtocolResolver.DOH
|
||||
|
||||
    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Build a DoH resolver backed by an ``AsyncHTTPSConnectionPool``.

        Extra keyword arguments recognized and consumed here (everything
        else is forwarded untouched to the connection pool):

        - ``path``: query endpoint, defaults to ``/resolve``.
        - ``rfc8484``: truthy to use binary RFC 8484 messages instead of
          the JSON API.
        - ``source_address``: ``"ip:port"`` string, converted to a
          ``(ip, port)`` tuple for the pool.
        - ``proxy`` / ``proxy_headers``: outgoing proxy configuration.
        - ``headers``: extra request headers given as ``"Name: value"``
          strings (single value or list).
        - ``disabled_svn``: HTTP versions to disable (``h11``/``h2``/``h3``).
        - ``on_post_connection``: callable invoked with ``ConnectionInfo``
          after each new connection.
        """
        super().__init__(server, port or 443, *patterns, **kwargs)

        # Default endpoint of the Google/Cloudflare-style JSON API.
        self._path: str = "/resolve"

        if "path" in kwargs:
            # "/" is treated as "no override"; keep the default in that case.
            if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
                self._path = kwargs["path"]
            kwargs.pop("path")

        # When True, speak binary RFC 8484 instead of the JSON API.
        self._rfc8484: bool = False

        if "rfc8484" in kwargs:
            if kwargs["rfc8484"]:
                self._rfc8484 = True
            kwargs.pop("rfc8484")

        assert self._server is not None

        if "source_address" in kwargs:
            if isinstance(kwargs["source_address"], str):
                # Expect "ip:port"; the pool wants a (host, port) tuple.
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)

                if bind_ip and bind_port.isdigit():
                    kwargs["source_address"] = (
                        bind_ip,
                        int(bind_port),
                    )
                else:
                    raise ValueError("invalid source_address given in parameters")
            else:
                raise ValueError("invalid source_address given in parameters")

        if "proxy" in kwargs:
            # The pool consumes the parsed URL under the private "_proxy" key.
            kwargs["_proxy"] = parse_url(kwargs["proxy"])
            kwargs.pop("proxy")

        if "maxsize" not in kwargs:
            kwargs["maxsize"] = 10

        # NOTE(review): proxy_headers are only converted when a proxy was
        # also given; otherwise the raw "proxy_headers" key is forwarded to
        # the pool as-is — confirm that is intended.
        if "proxy_headers" in kwargs and "_proxy" in kwargs:
            proxy_headers = HTTPHeaderDict()

            if not isinstance(kwargs["proxy_headers"], list):
                kwargs["proxy_headers"] = [kwargs["proxy_headers"]]

            for item in kwargs["proxy_headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")

                k, v = item.split(":", 1)
                proxy_headers.add(k, v)

            kwargs["_proxy_headers"] = proxy_headers

        if "headers" in kwargs:
            headers = HTTPHeaderDict()

            if not isinstance(kwargs["headers"], list):
                kwargs["headers"] = [kwargs["headers"]]

            for item in kwargs["headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")

                k, v = item.split(":", 1)
                headers.add(k, v)

            kwargs["headers"] = headers

        if "disabled_svn" in kwargs:
            if not isinstance(kwargs["disabled_svn"], list):
                kwargs["disabled_svn"] = [kwargs["disabled_svn"]]

            # Map the user-facing strings onto HttpVersion members; unknown
            # entries are silently dropped.
            disabled_svn = set()

            for svn in kwargs["disabled_svn"]:
                svn = svn.lower()

                if svn == "h11":
                    disabled_svn.add(HttpVersion.h11)
                elif svn == "h2":
                    disabled_svn.add(HttpVersion.h2)
                elif svn == "h3":
                    disabled_svn.add(HttpVersion.h3)

            kwargs["disabled_svn"] = disabled_svn

        if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
            # Stored and passed per-request to request_encode_url later on.
            self._connection_callback: (
                typing.Callable[[ConnectionInfo], None] | None
            ) = kwargs["on_post_connection"]
            kwargs.pop("on_post_connection")
        else:
            self._connection_callback = None

        self._pool = AsyncHTTPSConnectionPool(self._server, self._port, **kwargs)
|
||||
|
||||
    async def close(self) -> None:  # type: ignore[override]
        """Release the underlying HTTPS connection pool."""
        await self._pool.close()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._pool.pool is not None
|
||||
|
||||
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DoH, mimicking ``socket.getaddrinfo``.

        Sends A and/or AAAA queries (plus an HTTPS/type-65 query when
        *quic_upgrade_via_dns_rr* is set) to the configured endpoint, via the
        JSON API or — when ``rfc8484`` was enabled — binary RFC 8484
        messages.  Raises ``socket.gaierror`` on resolution failure, like
        the stdlib counterpart.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals short-circuit: no DNS query is issued.
        # NOTE(review): proto is hard-coded to 6 here and 17 in the IPv6
        # branch regardless of *type* — confirm this asymmetry is intended.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        promises = []
        remote_preemptive_quic_rr = False

        # An HTTPS RR upgrade hint is pointless when the caller already
        # asked for a datagram socket.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # A record (JSON type "1" / SupportedQueryType.A).
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            # RFC 8484 requires unpadded base64url in the
                            # "dns" query parameter.
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # AAAA record (JSON type "28").
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # HTTPS RR (type 65) — used only as an "h3 available" hint.
        if quic_upgrade_via_dns_rr:
            if not self._rfc8484:
                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    await self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        tasks = []
        responses = []

        for promise in promises:
            # already resolved
            if isinstance(promise, AsyncHTTPResponse):
                responses.append(promise)
                continue
            tasks.append(self._pool.get_response(promise=promise))

        if tasks:
            # Collect the remaining multiplexed responses as they complete.
            for waiting_promise_coro in as_completed(tasks):
                responses.append(await waiting_promise_coro)  # type: ignore[arg-type]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )

            if not self._rfc8484:
                # JSON API branch (Google/Cloudflare schema).
                payload = await response.json()

                assert "Status" in payload and isinstance(payload["Status"], int)

                # Non-zero Status means the remote resolver reported an error.
                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )

                    if isinstance(msg, list):
                        msg = ", ".join(msg)

                    raise socket.gaierror(msg)

                assert "Question" in payload and isinstance(payload["Question"], list)

                # No Answer section: nothing to add for this query type.
                if "Answer" not in payload:
                    continue

                assert isinstance(payload["Answer"], list)

                for answer in payload["Answer"]:
                    # Only A (1), AAAA (28) and HTTPS (65) records matter here.
                    if answer["type"] not in [1, 28, 65]:
                        continue

                    assert "data" in answer
                    assert isinstance(answer["data"], str)

                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]

                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            rr = "".join(rr[2:].split(" ")[2:])

                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""

                            if not raw_record:
                                continue

                            https_record = parse_https_rdata(raw_record)

                            if "h3" not in https_record["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True
                        else:
                            # Textual presentation: parse "key=value" pairs.
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )

                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True

                            # ipv4hint/ipv6hint give addresses usable directly
                            # over QUIC (UDP, proto 17).
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )

                        continue

                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )

                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # Binary RFC 8484 branch.
                dns_resp = DomainNameServerReturn(await response.data)

                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue

                    assert not isinstance(record[-1], dict)

                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )

                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )

        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        # When the HTTPS RR advertised h3, clone the TCP results as UDP
        # entries — unless the caller already asked for datagram sockets.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        # Sort so that AF_INET6/SOCK_DGRAM entries come first.
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class GoogleResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The endpoint is fixed; drop any caller-supplied server override
        # but honor an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Google serves binary RFC 8484 queries under /dns-query rather
        # than the JSON /resolve endpoint.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for Cloudflare DNS (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint: discard any server override, keep an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        # Cloudflare exposes its DoH API under /dns-query.
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for AdGuard unfiltered DNS."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint: discard any server override, keep an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        # AdGuard is queried with binary RFC 8484 messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for Cisco OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint: discard any server override, keep an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        # OpenDNS is queried with binary RFC 8484 messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint: discard any server override, keep an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        # Quad9 is queried with binary RFC 8484 messages on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DoH shortcut preconfigured for NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint: discard any server override, keep an explicit port.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# DNS-over-QUIC needs an optional backend; importing ._qh3 can raise
# ImportError (e.g. missing dependency, or the FIPS guard inside that
# module).  Fall back to None placeholders so callers can import the
# names unconditionally and test for availability.
try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    QUICResolver = None  # type: ignore
    AdGuardResolver = None  # type: ignore
    NextDNSResolver = None  # type: ignore


__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)
|
||||
@@ -0,0 +1,557 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from ssl import SSLError
|
||||
from time import time as monotonic
|
||||
|
||||
from qh3.quic.configuration import QuicConfiguration
|
||||
from qh3.quic.connection import QuicConnection
|
||||
from qh3.quic.events import (
|
||||
ConnectionTerminated,
|
||||
HandshakeCompleted,
|
||||
QuicEvent,
|
||||
StopSendingReceived,
|
||||
StreamDataReceived,
|
||||
StreamReset,
|
||||
)
|
||||
|
||||
from .....util.ssl_ import IS_FIPS, resolve_cert_reqs
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
if IS_FIPS:
|
||||
raise ImportError(
|
||||
"DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
|
||||
)
|
||||
|
||||
|
||||
class QUICResolver(PlainResolver):
|
||||
protocol = ProtocolResolver.DOQ
|
||||
implementation = "qh3"
|
||||
|
||||
    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        """Build a DNS-over-QUIC resolver (default port 853).

        Consumes TLS-related keyword arguments (``ca_certs`` /
        ``ca_cert_data`` / ``cert_reqs`` / ``cert_file`` / ``cert_data`` /
        ``key_*`` / ``server_hostname``) to configure the underlying qh3
        ``QuicConnection``.  Raises ``ssl.SSLError`` when certificate
        verification is requested but no CA could be loaded.
        """
        super().__init__(server, port or 853, *patterns, **kwargs)

        # qh3 load_default_certs seems off. need to investigate.
        # Workaround: harvest the system CAs through an ssl.SSLContext and
        # hand them to qh3 as concatenated PEM data.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

            try:
                ctx.load_default_certs()

                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))

                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                # Harvesting failed: drop the partial list and fall through
                # to the "no CA available" check below.
                del kwargs["ca_cert_data"]

        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )

        configuration = QuicConfiguration(
            is_client=True,
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )

        # Optional client certificate (mutual TLS), from file or in-memory.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )

        self._quic = QuicConnection(configuration=configuration)

        # Serializes read access; connect/handshake are done exactly once.
        self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._connect_attempt: asyncio.Event = asyncio.Event()
        self._handshake_event: asyncio.Event = asyncio.Event()

        self._terminated: bool = False
        self._should_disconnect: bool = False

        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True

        # Responses read for a concurrent caller / queries awaiting answers.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
    async def close(self) -> None:  # type: ignore[override]
        """Terminate the QUIC connection and release the UDP socket."""
        if (
            not self._terminated
            and self._socket is not None
            and not self._socket.should_connect()
        ):
            self._quic.close()

            # Flush the CONNECTION_CLOSE frame(s) onto the wire before
            # tearing the socket down.
            while True:
                datagrams = self._quic.datagrams_to_send(monotonic())

                if not datagrams:
                    break

                for datagram in datagrams:
                    data, addr = datagram
                    await self._socket.sendall(data)

            self._socket.close()
            await self._socket.wait_for_close()
            self._terminated = True
        # Never-connected (or reconnect-pending) resolver: just mark closed.
        if self._socket is None or self._socket.should_connect():
            self._terminated = True
|
||||
|
||||
    def is_available(self) -> bool:
        """Whether this resolver can still serve queries."""
        if self._terminated:
            return False
        # Not connected yet: a first query could still establish the link.
        if self._socket is None or self._socket.should_connect():
            return True
        # Let qh3 run its timer bookkeeping (idle timeout, loss detection)
        # before inspecting the connection state.
        self._quic.handle_timer(monotonic())
        # NOTE(review): reaches into qh3 private state (_close_event) to
        # detect a dead connection — confirm against the qh3 version in use.
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated
|
||||
|
||||
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over DNS-over-QUIC, mimicking ``socket.getaddrinfo``.

        Lazily establishes the QUIC connection on first use, sends one
        query per requested record type on its own stream, then collects
        and matches responses by DNS message id.  Raises ``socket.gaierror``
        on resolution failure, like the stdlib counterpart.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals short-circuit: no DNS query is issued.
        # NOTE(review): proto is hard-coded to 6 here and 17 in the IPv6
        # branch regardless of *type* — confirm this asymmetry is intended.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        # First caller wins the right to connect; everyone else waits on
        # the handshake event.
        if self._socket is None and self._connect_attempt.is_set() is False:
            self._connect_attempt.set()
            assert self.server is not None
            self._quic.connect((self._server, self._port), monotonic())
            self._socket = await SystemResolver().create_connection(
                (self.server, self.port or 853),
                timeout=self._timeout,
                source_address=self._source_address,
                socket_options=None,
                socket_kind=self._socket_type,
            )
            await self.__exchange_until(HandshakeCompleted, receive_first=False)
            self._handshake_event.set()
        else:
            await self._handshake_event.wait()

        assert self._socket is not None

        remote_preemptive_quic_rr = False

        # An HTTPS RR upgrade hint is pointless when the caller already
        # asked for a datagram socket.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Record types to be queried.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)
        open_streams = []

        # One query per QUIC stream, each finished with end_stream=True.
        for q in queries:
            payload = bytes(q)

            self._pending.append(q)

            if self._rfc1035_prefix_mandated is True:
                payload = rfc1035_pack(payload)

            stream_id = self._quic.get_next_available_stream_id()
            self._quic.send_stream_data(stream_id, payload, True)

            open_streams.append(stream_id)

        for dg in self._quic.datagrams_to_send(monotonic()):
            await self._socket.sendall(dg[0])

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            await self._read_semaphore.acquire()
            # A concurrent caller may already have read one of our answers.
            if self._unconsumed:
                dns_resp = None
                for query in queries:
                    for unconsumed in self._unconsumed:
                        if unconsumed.id == query.id:
                            dns_resp = unconsumed
                            responses.append(dns_resp)
                            break
                    if dns_resp:
                        break
                if dns_resp:
                    self._unconsumed.remove(dns_resp)
                    self._pending.remove(query)
                    self._read_semaphore.release()
                    continue

            try:
                events: list[StreamDataReceived] = await self.__exchange_until(  # type: ignore[assignment]
                    StreamDataReceived,
                    receive_first=True,
                    event_type_collectable=(StreamDataReceived,),
                    respect_end_stream_signal=False,
                )

                payload = b"".join([e.data for e in events])

                # Keep reading while the 2-byte length prefix says the
                # message is incomplete.
                while rfc1035_should_read(payload):
                    events.extend(
                        await self.__exchange_until(  # type: ignore[arg-type]
                            StreamDataReceived,
                            receive_first=True,
                            event_type_collectable=(StreamDataReceived,),
                            respect_end_stream_signal=False,
                        )
                    )
                    payload = b"".join([e.data for e in events])
            except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                raise socket.gaierror(
                    "Got unexpectedly disconnected while waiting for name resolution"
                ) from e

            self._read_semaphore.release()

            if not payload:
                continue

            #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
            fragments = rfc1035_unpack(payload)

            for fragment in fragments:
                dns_resp = DomainNameServerReturn(fragment)

                if any(dns_resp.id == _.id for _ in queries):
                    responses.append(dns_resp)

                    query_tbr: DomainNameServerQuery | None = None

                    # NOTE(review): if no id matches, query_tbr ends up
                    # bound to the last pending query and is removed anyway
                    # by the truthiness check below — confirm intended.
                    for query_tbr in self._pending:
                        if query_tbr.id == dns_resp.id:
                            break
                    if query_tbr:
                        self._pending.remove(query_tbr)
                else:
                    # Answer for another caller's query: park it.
                    self._unconsumed.append(dns_resp)

        if self._should_disconnect:
            await self.close()
            self._should_disconnect = False
            self._terminated = True

        results = []

        for response in responses:
            if not response.is_ok:
                # rcode 2 (SERVFAIL) is most commonly a DNSSEC issue here.
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RR only feeds the h3-upgrade hint below.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        # When the HTTPS RR advertised h3, clone the TCP results as UDP
        # entries — unless the caller already asked for datagram sockets.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        # Sort so that AF_INET6/SOCK_DGRAM entries come first.
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
async def __exchange_until(
|
||||
self,
|
||||
event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
|
||||
*,
|
||||
receive_first: bool = False,
|
||||
event_type_collectable: type[QuicEvent]
|
||||
| tuple[type[QuicEvent], ...]
|
||||
| None = None,
|
||||
respect_end_stream_signal: bool = True,
|
||||
) -> list[QuicEvent]:
|
||||
assert self._socket is not None
|
||||
|
||||
while True:
|
||||
if receive_first is False:
|
||||
now = monotonic()
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
events = []
|
||||
|
||||
while True:
|
||||
if not self._quic._events:
|
||||
data_in = await self._socket.recv(1500)
|
||||
|
||||
if not data_in:
|
||||
break
|
||||
|
||||
now = monotonic()
|
||||
|
||||
if not isinstance(data_in, list):
|
||||
self._quic.receive_datagram(
|
||||
data_in, (self._server, self._port), now
|
||||
)
|
||||
else:
|
||||
for gro_segment in data_in:
|
||||
self._quic.receive_datagram(
|
||||
gro_segment, (self._server, self._port), now
|
||||
)
|
||||
|
||||
while True:
|
||||
datagrams = self._quic.datagrams_to_send(now)
|
||||
|
||||
if not datagrams:
|
||||
break
|
||||
|
||||
for datagram in datagrams:
|
||||
data, addr = datagram
|
||||
await self._socket.sendall(data)
|
||||
|
||||
for ev in iter(self._quic.next_event, None):
|
||||
if isinstance(ev, ConnectionTerminated):
|
||||
if ev.error_code == 298:
|
||||
raise SSLError(
|
||||
"DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
|
||||
)
|
||||
elif isinstance(ev, StreamReset):
|
||||
self._terminated = True
|
||||
raise socket.gaierror(
|
||||
"DNS over QUIC server submitted a StreamReset. A request was rejected."
|
||||
)
|
||||
elif isinstance(ev, StopSendingReceived):
|
||||
self._should_disconnect = True
|
||||
continue
|
||||
|
||||
if event_type_collectable:
|
||||
if isinstance(ev, event_type_collectable):
|
||||
events.append(ev)
|
||||
else:
|
||||
events.append(ev)
|
||||
|
||||
if isinstance(ev, event_type):
|
||||
if not respect_end_stream_signal:
|
||||
return events
|
||||
if hasattr(ev, "stream_ended") and ev.stream_ended:
|
||||
return events
|
||||
elif hasattr(ev, "stream_ended") is False:
|
||||
return events
|
||||
|
||||
return events
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
QUICResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
|
||||
QUICResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "nextdns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any):
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._ssl import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
TLSResolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"TLSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"Quad9Resolver",
|
||||
"OpenDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,197 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util._async.ssl_ import ssl_wrap_socket
|
||||
from .....util.ssl_ import resolve_cert_reqs
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..dou import PlainResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class TLSResolver(PlainResolver):
|
||||
"""
|
||||
Basic DNS resolver over TLS.
|
||||
Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
|
||||
"""
|
||||
|
||||
protocol = ProtocolResolver.DOT
|
||||
implementation = "ssl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
self._socket_type = socket.SOCK_STREAM
|
||||
|
||||
super().__init__(server, port or 853, *patterns, **kwargs)
|
||||
|
||||
# DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
|
||||
self._rfc1035_prefix_mandated = True
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if self._socket is None and self._connect_attempt.is_set() is False:
|
||||
assert self.server is not None
|
||||
self._connect_attempt.set()
|
||||
self._socket = await SystemResolver().create_connection(
|
||||
(self.server, self.port or 853),
|
||||
timeout=self._timeout,
|
||||
source_address=self._source_address,
|
||||
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
|
||||
socket_kind=self._socket_type,
|
||||
)
|
||||
self._socket = await ssl_wrap_socket(
|
||||
self._socket,
|
||||
server_hostname=self.server
|
||||
if "server_hostname" not in self._kwargs
|
||||
else self._kwargs["server_hostname"],
|
||||
keyfile=self._kwargs["key_file"]
|
||||
if "key_file" in self._kwargs
|
||||
else None,
|
||||
certfile=self._kwargs["cert_file"]
|
||||
if "cert_file" in self._kwargs
|
||||
else None,
|
||||
cert_reqs=resolve_cert_reqs(self._kwargs["cert_reqs"])
|
||||
if "cert_reqs" in self._kwargs
|
||||
else None,
|
||||
ca_certs=self._kwargs["ca_certs"]
|
||||
if "ca_certs" in self._kwargs
|
||||
else None,
|
||||
ssl_version=self._kwargs["ssl_version"]
|
||||
if "ssl_version" in self._kwargs
|
||||
else None,
|
||||
ciphers=self._kwargs["ciphers"] if "ciphers" in self._kwargs else None,
|
||||
ca_cert_dir=self._kwargs["ca_cert_dir"]
|
||||
if "ca_cert_dir" in self._kwargs
|
||||
else None,
|
||||
key_password=self._kwargs["key_password"]
|
||||
if "key_password" in self._kwargs
|
||||
else None,
|
||||
ca_cert_data=self._kwargs["ca_cert_data"]
|
||||
if "ca_cert_data" in self._kwargs
|
||||
else None,
|
||||
certdata=self._kwargs["cert_data"]
|
||||
if "cert_data" in self._kwargs
|
||||
else None,
|
||||
keydata=self._kwargs["key_data"]
|
||||
if "key_data" in self._kwargs
|
||||
else None,
|
||||
)
|
||||
self._connect_finalized.set()
|
||||
|
||||
return await super().getaddrinfo(
|
||||
host,
|
||||
port,
|
||||
family=family,
|
||||
type=type,
|
||||
proto=proto,
|
||||
flags=flags,
|
||||
quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
|
||||
)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "opendns"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
TLSResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
PlainResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"PlainResolver",
|
||||
"CloudflareResolver",
|
||||
"GoogleResolver",
|
||||
"Quad9Resolver",
|
||||
"AdGuardResolver",
|
||||
)
|
||||
@@ -0,0 +1,431 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from ....ssa import AsyncSocket
|
||||
from ...protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ...utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
packet_fragment,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
from ..protocols import AsyncBaseResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class PlainResolver(AsyncBaseResolver):
|
||||
"""
|
||||
Minimalist DNS resolver over UDP
|
||||
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
|
||||
|
||||
EDNS is not supported, yet. But we plan to. Willing to contribute?
|
||||
"""
|
||||
|
||||
protocol = ProtocolResolver.DOU
|
||||
implementation = "socket"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server: str | None,
|
||||
port: int | None = None,
|
||||
*patterns: str,
|
||||
**kwargs: typing.Any,
|
||||
) -> None:
|
||||
super().__init__(server, port, *patterns, **kwargs)
|
||||
|
||||
self._socket: AsyncSocket | None = None
|
||||
|
||||
if not hasattr(self, "_socket_type"):
|
||||
self._socket_type = socket.SOCK_DGRAM
|
||||
|
||||
if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
|
||||
if ":" in kwargs["source_address"]:
|
||||
bind_ip, bind_port = kwargs["source_address"].split(":", 1)
|
||||
self._source_address: tuple[str, int] | None = (bind_ip, int(bind_port))
|
||||
else:
|
||||
self._source_address = (kwargs["source_address"], 0)
|
||||
else:
|
||||
self._source_address = None
|
||||
|
||||
if "timeout" in kwargs and isinstance(
|
||||
kwargs["timeout"],
|
||||
(
|
||||
float,
|
||||
int,
|
||||
),
|
||||
):
|
||||
self._timeout: float | int | None = kwargs["timeout"]
|
||||
else:
|
||||
self._timeout = None
|
||||
|
||||
#: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
|
||||
self._rfc1035_prefix_mandated: bool = False
|
||||
|
||||
self._unconsumed: deque[DomainNameServerReturn] = deque()
|
||||
self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
self._read_semaphore: asyncio.Semaphore = asyncio.Semaphore()
|
||||
self._connect_attempt: asyncio.Event = asyncio.Event()
|
||||
self._connect_finalized: asyncio.Event = asyncio.Event()
|
||||
|
||||
self._terminated: bool = False
|
||||
|
||||
async def close(self) -> None: # type: ignore[override]
|
||||
if not self._terminated:
|
||||
with self._lock:
|
||||
if self._socket is not None:
|
||||
self._socket.close()
|
||||
await self._socket.wait_for_close()
|
||||
self._terminated = True
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return not self._terminated
|
||||
|
||||
async def getaddrinfo( # type: ignore[override]
|
||||
self,
|
||||
host: bytes | str | None,
|
||||
port: str | int | None,
|
||||
family: socket.AddressFamily,
|
||||
type: socket.SocketKind,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
*,
|
||||
quic_upgrade_via_dns_rr: bool = False,
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if host is None:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Tried to resolve 'localhost' from a PlainResolver"
|
||||
)
|
||||
|
||||
if port is None:
|
||||
port = 0 # Defensive: stdlib cpy behavior
|
||||
if isinstance(port, str):
|
||||
port = int(port) # Defensive: stdlib cpy behavior
|
||||
if port < 0:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Servname not supported for ai_socktype"
|
||||
)
|
||||
|
||||
if isinstance(host, bytes):
|
||||
host = host.decode("ascii") # Defensive: stdlib cpy behavior
|
||||
|
||||
if is_ipv4(host):
|
||||
if family == socket.AF_INET6:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET,
|
||||
type,
|
||||
6,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
),
|
||||
)
|
||||
]
|
||||
elif is_ipv6(host):
|
||||
if family == socket.AF_INET:
|
||||
raise socket.gaierror( # Defensive: stdlib cpy behavior
|
||||
"Address family for hostname not supported"
|
||||
)
|
||||
return [
|
||||
(
|
||||
socket.AF_INET6,
|
||||
type,
|
||||
17,
|
||||
"",
|
||||
(
|
||||
host,
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
validate_length_of(host)
|
||||
|
||||
if self._socket is None and self._connect_attempt.is_set() is False:
|
||||
self._connect_attempt.set()
|
||||
assert self.server is not None
|
||||
self._socket = await SystemResolver().create_connection(
|
||||
(self.server, self.port or 53),
|
||||
timeout=self._timeout,
|
||||
source_address=self._source_address,
|
||||
socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
|
||||
socket_kind=self._socket_type,
|
||||
)
|
||||
self._connect_finalized.set()
|
||||
else:
|
||||
await self._connect_finalized.wait()
|
||||
assert self._socket is not None
|
||||
await self._socket.wait_for_readiness()
|
||||
|
||||
remote_preemptive_quic_rr = False
|
||||
|
||||
if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
|
||||
quic_upgrade_via_dns_rr = False
|
||||
|
||||
tbq = []
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET]:
|
||||
tbq.append(SupportedQueryType.A)
|
||||
|
||||
if family in [socket.AF_UNSPEC, socket.AF_INET6]:
|
||||
tbq.append(SupportedQueryType.AAAA)
|
||||
|
||||
if quic_upgrade_via_dns_rr:
|
||||
tbq.append(SupportedQueryType.HTTPS)
|
||||
|
||||
queries = DomainNameServerQuery.bulk(host, *tbq)
|
||||
|
||||
for q in queries:
|
||||
payload = bytes(q)
|
||||
self._pending.append(q)
|
||||
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
payload = rfc1035_pack(payload)
|
||||
|
||||
await self._socket.sendall(payload)
|
||||
|
||||
responses: list[DomainNameServerReturn] = []
|
||||
|
||||
while len(responses) < len(tbq):
|
||||
await self._read_semaphore.acquire()
|
||||
#: There we want to verify if another thread got a response that belong to this thread.
|
||||
if self._unconsumed:
|
||||
dns_resp = None
|
||||
|
||||
for query in queries:
|
||||
for unconsumed in self._unconsumed:
|
||||
if unconsumed.id == query.id:
|
||||
dns_resp = unconsumed
|
||||
responses.append(dns_resp)
|
||||
break
|
||||
if dns_resp:
|
||||
break
|
||||
|
||||
if dns_resp:
|
||||
self._pending.remove(query)
|
||||
self._unconsumed.remove(dns_resp)
|
||||
self._read_semaphore.release()
|
||||
continue
|
||||
|
||||
try:
|
||||
data_in_or_segments = await self._socket.recv(1500)
|
||||
|
||||
if isinstance(data_in_or_segments, list):
|
||||
payloads = data_in_or_segments
|
||||
elif data_in_or_segments:
|
||||
payloads = [data_in_or_segments]
|
||||
else:
|
||||
payloads = []
|
||||
|
||||
if self._rfc1035_prefix_mandated is True and payloads:
|
||||
payload = b"".join(payloads)
|
||||
while rfc1035_should_read(payload):
|
||||
extra = await self._socket.recv(1500)
|
||||
if isinstance(extra, list):
|
||||
payload += b"".join(extra)
|
||||
else:
|
||||
payload += extra
|
||||
payloads = [payload]
|
||||
except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
) from e
|
||||
|
||||
self._read_semaphore.release()
|
||||
|
||||
if not payloads:
|
||||
self._terminated = True
|
||||
raise socket.gaierror(
|
||||
"Got unexpectedly disconnected while waiting for name resolution"
|
||||
)
|
||||
|
||||
pending_raw_identifiers = [_.raw_id for _ in self._pending]
|
||||
|
||||
for payload in payloads:
|
||||
#: We can receive two responses at once (or more, concatenated). Let's unwrap them.
|
||||
if self._rfc1035_prefix_mandated is True:
|
||||
fragments = rfc1035_unpack(payload)
|
||||
else:
|
||||
fragments = packet_fragment(payload, *pending_raw_identifiers)
|
||||
|
||||
for fragment in fragments:
|
||||
dns_resp = DomainNameServerReturn(fragment)
|
||||
|
||||
if any(dns_resp.id == _.id for _ in queries):
|
||||
responses.append(dns_resp)
|
||||
|
||||
query_tbr: DomainNameServerQuery | None = None
|
||||
|
||||
for query_tbr in self._pending:
|
||||
if query_tbr.id == dns_resp.id:
|
||||
break
|
||||
|
||||
if query_tbr:
|
||||
self._pending.remove(query_tbr)
|
||||
else:
|
||||
self._unconsumed.append(dns_resp)
|
||||
|
||||
results = []
|
||||
|
||||
for response in responses:
|
||||
if not response.is_ok:
|
||||
if response.rcode == 2:
|
||||
raise socket.gaierror(
|
||||
f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
|
||||
)
|
||||
raise socket.gaierror(
|
||||
f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
|
||||
)
|
||||
|
||||
for record in response.records:
|
||||
if record[0] == SupportedQueryType.HTTPS:
|
||||
assert isinstance(record[-1], dict)
|
||||
if "h3" in record[-1]["alpn"]:
|
||||
remote_preemptive_quic_rr = True
|
||||
continue
|
||||
|
||||
assert not isinstance(record[-1], dict)
|
||||
|
||||
inet_type = (
|
||||
socket.AF_INET
|
||||
if record[0] == SupportedQueryType.A
|
||||
else socket.AF_INET6
|
||||
)
|
||||
dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
|
||||
(
|
||||
record[-1],
|
||||
port,
|
||||
)
|
||||
if inet_type == socket.AF_INET
|
||||
else (
|
||||
record[-1],
|
||||
port,
|
||||
0,
|
||||
0,
|
||||
)
|
||||
)
|
||||
|
||||
results.append(
|
||||
(
|
||||
inet_type,
|
||||
type,
|
||||
6 if type == socket.SOCK_STREAM else 17,
|
||||
"",
|
||||
dst_addr,
|
||||
)
|
||||
)
|
||||
|
||||
quic_results = []
|
||||
|
||||
if remote_preemptive_quic_rr:
|
||||
any_specified = False
|
||||
|
||||
for result in results:
|
||||
if result[1] == socket.SOCK_STREAM:
|
||||
quic_results.append(
|
||||
(result[0], socket.SOCK_DGRAM, 17, "", result[4])
|
||||
)
|
||||
else:
|
||||
any_specified = True
|
||||
break
|
||||
|
||||
if any_specified:
|
||||
quic_results = []
|
||||
|
||||
return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class CloudflareResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "cloudflare"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class GoogleResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "google"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("8.8.8.8", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "quad9"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("9.9.9.9", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
|
||||
PlainResolver
|
||||
): # Defensive: we do not cover specific vendors/DNS shortcut
|
||||
specifier = "adguard"
|
||||
|
||||
def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
|
||||
if "server" in kwargs:
|
||||
kwargs.pop("server")
|
||||
if "port" in kwargs:
|
||||
port = kwargs["port"]
|
||||
kwargs.pop("port")
|
||||
else:
|
||||
port = None
|
||||
|
||||
super().__init__("94.140.14.140", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,243 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from base64 import b64encode
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from ....util import parse_url
|
||||
from ..factories import ResolverDescription
|
||||
from ..protocols import ProtocolResolver
|
||||
from .protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class AsyncResolverFactory(metaclass=ABCMeta):
|
||||
@staticmethod
|
||||
def has(
|
||||
protocol: ProtocolResolver,
|
||||
specifier: str | None = None,
|
||||
implementation: str | None = None,
|
||||
) -> bool:
|
||||
package_name: str = __name__.split(".")[0]
|
||||
module_expr = f".{protocol.value.replace('-', '_')}"
|
||||
|
||||
if implementation:
|
||||
module_expr += f"._{implementation.replace('-', '_').lower()}"
|
||||
|
||||
try:
|
||||
resolver_module = importlib.import_module(
|
||||
module_expr, f"{package_name}.contrib.resolver._async"
|
||||
)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
|
||||
resolver_module,
|
||||
lambda e: isinstance(e, type)
|
||||
and issubclass(e, AsyncBaseResolver)
|
||||
and (
|
||||
(specifier is None and e.specifier is None) or specifier == e.specifier
|
||||
),
|
||||
)
|
||||
|
||||
if not implementations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def new(
|
||||
protocol: ProtocolResolver,
|
||||
specifier: str | None = None,
|
||||
implementation: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncBaseResolver:
|
||||
package_name: str = __name__.split(".")[0]
|
||||
|
||||
module_expr = f".{protocol.value.replace('-', '_')}"
|
||||
|
||||
if implementation:
|
||||
module_expr += f"._{implementation.replace('-', '_').lower()}"
|
||||
|
||||
spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '
|
||||
|
||||
try:
|
||||
resolver_module = importlib.import_module(
|
||||
module_expr, f"{package_name}.contrib.resolver._async"
|
||||
)
|
||||
except ImportError as e:
|
||||
raise NotImplementedError(
|
||||
f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
|
||||
) from e
|
||||
|
||||
implementations: list[tuple[str, type[AsyncBaseResolver]]] = inspect.getmembers(
|
||||
resolver_module,
|
||||
lambda e: isinstance(e, type)
|
||||
and issubclass(e, AsyncBaseResolver)
|
||||
and (
|
||||
(specifier is None and e.specifier is None) or specifier == e.specifier
|
||||
)
|
||||
and hasattr(e, "protocol")
|
||||
and e.protocol == protocol,
|
||||
)
|
||||
|
||||
if not implementations:
|
||||
raise NotImplementedError(
|
||||
f"{protocol}{spe_msg}cannot be loaded. "
|
||||
"No compatible implementation available. "
|
||||
"Make sure your implementation inherit from BaseResolver."
|
||||
)
|
||||
|
||||
implementation_target: type[AsyncBaseResolver] = implementations.pop()[1]
|
||||
|
||||
return implementation_target(**kwargs)
|
||||
|
||||
|
||||
class AsyncResolverDescription(ResolverDescription):
|
||||
"""Describe how a BaseResolver must be instantiated."""
|
||||
|
||||
def new(self) -> AsyncBaseResolver:
|
||||
kwargs = {**self.kwargs}
|
||||
|
||||
if self.server:
|
||||
kwargs["server"] = self.server
|
||||
if self.port:
|
||||
kwargs["port"] = self.port
|
||||
if self.host_patterns:
|
||||
kwargs["patterns"] = self.host_patterns
|
||||
|
||||
return AsyncResolverFactory.new(
|
||||
self.protocol,
|
||||
self.specifier,
|
||||
self.implementation,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_url(url: str) -> AsyncResolverDescription:
|
||||
parsed_url = parse_url(url)
|
||||
|
||||
schema = parsed_url.scheme
|
||||
|
||||
if schema is None:
|
||||
raise ValueError("Given DNS url is missing a protocol")
|
||||
|
||||
specifier = None
|
||||
implementation = None
|
||||
|
||||
if "+" in schema:
|
||||
schema, specifier = tuple(schema.lower().split("+", 1))
|
||||
|
||||
protocol = ProtocolResolver(schema)
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
if parsed_url.path:
|
||||
kwargs["path"] = parsed_url.path
|
||||
|
||||
if parsed_url.auth:
|
||||
kwargs["headers"] = dict()
|
||||
if ":" in parsed_url.auth:
|
||||
username, password = parsed_url.auth.split(":")
|
||||
|
||||
username = username.strip("'\"")
|
||||
password = password.strip("'\"")
|
||||
|
||||
kwargs["headers"]["Authorization"] = (
|
||||
f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
|
||||
)
|
||||
else:
|
||||
kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"
|
||||
|
||||
if parsed_url.query:
|
||||
parameters = parse_qs(parsed_url.query)
|
||||
|
||||
for parameter in parameters:
|
||||
if not parameters[parameter]:
|
||||
continue
|
||||
|
||||
parameter_insensible = parameter.lower()
|
||||
|
||||
if (
|
||||
isinstance(parameters[parameter], list)
|
||||
and len(parameters[parameter]) > 1
|
||||
):
|
||||
if parameter == "implementation":
|
||||
raise ValueError("Only one implementation can be passed to URL")
|
||||
|
||||
values = []
|
||||
|
||||
for e in parameters[parameter]:
|
||||
if "," in e:
|
||||
values.extend(e.split(","))
|
||||
else:
|
||||
values.append(e)
|
||||
|
||||
if parameter_insensible in kwargs:
|
||||
if isinstance(kwargs[parameter_insensible], list):
|
||||
kwargs[parameter_insensible].extend(values)
|
||||
else:
|
||||
values.append(kwargs[parameter_insensible])
|
||||
kwargs[parameter_insensible] = values
|
||||
continue
|
||||
|
||||
kwargs[parameter_insensible] = values
|
||||
continue
|
||||
|
||||
value: str = parameters[parameter][0].lower().strip(" ")
|
||||
|
||||
if parameter == "implementation":
|
||||
implementation = value
|
||||
continue
|
||||
|
||||
if "," in value:
|
||||
list_of_values = value.split(",")
|
||||
|
||||
if parameter_insensible in kwargs:
|
||||
if isinstance(kwargs[parameter_insensible], list):
|
||||
kwargs[parameter_insensible].extend(list_of_values)
|
||||
else:
|
||||
list_of_values.append(kwargs[parameter_insensible])
|
||||
continue
|
||||
|
||||
kwargs[parameter_insensible] = list_of_values
|
||||
continue
|
||||
|
||||
value_converted: bool | int | float | None = None
|
||||
|
||||
if value in ["false", "true"]:
|
||||
value_converted = True if value == "true" else False
|
||||
elif value.isdigit():
|
||||
value_converted = int(value)
|
||||
elif (
|
||||
value.count(".") == 1
|
||||
and value.index(".") > 0
|
||||
and value.replace(".", "").isdigit()
|
||||
):
|
||||
value_converted = float(value)
|
||||
|
||||
kwargs[parameter_insensible] = (
|
||||
value if value_converted is None else value_converted
|
||||
)
|
||||
|
||||
host_patterns: list[str] = []
|
||||
|
||||
if "hosts" in kwargs:
|
||||
host_patterns = (
|
||||
kwargs["hosts"].split(",")
|
||||
if isinstance(kwargs["hosts"], str)
|
||||
else kwargs["hosts"]
|
||||
)
|
||||
del kwargs["hosts"]
|
||||
|
||||
return AsyncResolverDescription(
|
||||
protocol,
|
||||
specifier,
|
||||
implementation,
|
||||
parsed_url.host,
|
||||
parsed_url.port,
|
||||
*host_patterns,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._dict import InMemoryResolver
|
||||
|
||||
__all__ = ("InMemoryResolver",)
|
||||
@@ -0,0 +1,186 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from .....util.url import _IPV6_ADDRZ_RE
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class InMemoryResolver(AsyncBaseResolver):
    """Manual, in-memory DNS resolver (async flavor).

    Serves only records registered by hand — either via constructor
    ``patterns`` of the form ``"hostname:address"`` or via :meth:`register`.
    Performs no network I/O whatsoever.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no remote endpoint: drop server/port if passed.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)

        # Cap on the number of distinct hostnames kept in memory.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        # hostname -> list of (address family, address literal) records.
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}

        if self._host_patterns:
            for record in self._host_patterns:
                # Patterns double as records here: "hostname:addr". Skip malformed ones.
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            # Patterns were consumed as records; keep no host constraints.
            self._host_patterns = tuple([])

    def recycle(self) -> AsyncBaseResolver:
        # No transport state to reset: safe to reuse the same instance.
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Always constrained: only answers for explicitly registered hosts.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True when ``hostname`` has at least one registered record."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add a record for ``hostname``; already-known addresses are ignored."""
        if hostname not in self._hosts:
            self._hosts[hostname] = []
        else:
            for e in self._hosts[hostname]:
                t, addr = e
                # Containment (not equality) so a stored "::1" also matches "[::1]".
                if addr in ipaddr:
                    return

        if _IPV6_ADDRZ_RE.match(ipaddr):
            # Bracketed IPv6 literal: store it without the enclosing brackets.
            self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
        elif is_ipv6(ipaddr):
            self._hosts[hostname].append((socket.AF_INET6, ipaddr))
        else:
            self._hosts[hostname].append((socket.AF_INET, ipaddr))

        if len(self._hosts) > self._maxsize:
            # Evict the first (oldest-inserted) hostname to bound memory usage.
            k = None
            for k in self._hosts.keys():
                break
            if k:
                self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Forget every record registered for ``hostname``."""
        if hostname in self._hosts:
            del self._hosts[hostname]

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` from the in-memory table, mimicking socket.getaddrinfo.

        Raises :class:`socket.gaierror` when the host is unknown or no record
        matches the requested address family.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals short-circuit the table lookup entirely.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    # NOTE(review): proto fixed to 6 (TCP) here even for SOCK_DGRAM — confirm intended
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        if host not in self._hosts:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        for entry in self._hosts[host]:
            addr_type, addr_target = entry

            # Honor the requested address family, AF_UNSPEC matches everything.
            if family != socket.AF_UNSPEC:
                if family != addr_type:
                    continue

            results.append(
                (
                    addr_type,
                    type,
                    6 if type == socket.SOCK_STREAM else 17,
                    "",
                    (addr_target, port)
                    if addr_type == socket.AF_INET
                    else (addr_target, port, 0, 0),
                )
            )

        # Records existed but none matched the requested family.
        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        return results
|
||||
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ...utils import is_ipv4, is_ipv6
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class NullResolver(AsyncBaseResolver):
    """Resolver that deliberately resolves nothing.

    IP literals pass straight through; any actual hostname lookup raises
    :class:`socket.gaierror`. Useful to forbid DNS resolution entirely.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # No upstream server involved: silently discard server/port kwargs.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> AsyncBaseResolver:
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Pass IP literals through; raise for any real hostname."""
        # Normalize inputs the same way CPython's socket module does.
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior
        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        elif isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            # Defensive: stdlib cpy behavior
            raise socket.gaierror("Servname not supported for ai_socktype")

        if is_ipv4(host):
            if family == socket.AF_INET6:
                # Defensive: stdlib cpy behavior
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET, type, 6, "", (host, port))]

        if is_ipv6(host):
            if family == socket.AF_INET:
                # Defensive: stdlib cpy behavior
                raise socket.gaierror("Address family for hostname not supported")
            return [(socket.AF_INET6, type, 17, "", (host, port, 0, 0))]

        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
|
||||
|
||||
|
||||
__all__ = ("NullResolver",)
|
||||
@@ -0,0 +1,375 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from ...._constant import UDP_LINUX_GRO
|
||||
from ...._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
|
||||
from ....exceptions import LocationParseError
|
||||
from ....util.connection import _set_socket_options, allowed_gai_family
|
||||
from ....util.timeout import _DEFAULT_TIMEOUT
|
||||
from ...ssa import AsyncSocket
|
||||
from ...ssa._timeout import timeout as timeout_
|
||||
from ..protocols import BaseResolver
|
||||
|
||||
|
||||
class AsyncBaseResolver(BaseResolver, metaclass=ABCMeta):
    """Async counterpart of :class:`BaseResolver`.

    Concrete subclasses must provide awaitable :meth:`close` and
    :meth:`getaddrinfo`; :meth:`create_connection` builds on the latter to
    open an :class:`AsyncSocket` toward a resolved address.
    """

    def recycle(self) -> AsyncBaseResolver:
        return super().recycle()  # type: ignore[return-value]

    @abstractmethod
    async def close(self) -> None:  # type: ignore[override]
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    async def create_connection(  # type: ignore[override]
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> AsyncSocket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.
        """

        host, port = address
        # Strip IPv6 brackets so getaddrinfo sees the bare literal.
        if host.startswith("["):
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        # A bound source address pins the family: only records of the same
        # family can be connected from it.
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        dt_pre_resolve = datetime.now(tz=timezone.utc)
        if timeout is not _DEFAULT_TIMEOUT and timeout is not None:
            # we can hang here in case of bad networking conditions
            # the DNS may never answer or the packets can be lost.
            # this isn't possible in sync mode. unfortunately.
            # found by user at https://github.com/jawah/niquests/issues/183
            # todo: find a way to limit getaddrinfo delays in sync mode.
            try:
                async with timeout_(timeout):
                    records = await self.getaddrinfo(
                        host,
                        port,
                        default_socket_family,
                        socket_kind,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )
            except TimeoutError:
                raise socket.gaierror(
                    f"unable to resolve '{host}' within timeout. the DNS server may be unresponsive."
                )
        else:
            records = await self.getaddrinfo(
                host,
                port,
                default_socket_family,
                socket_kind,
                quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
            )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)
        # Try each resolved record in order; the first successful connect wins.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = AsyncSocket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (OSError, AttributeError):  # Defensive: very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)
                if source_address:
                    sock.bind(source_address)

                try:
                    await sock.connect(sa)
                except asyncio.CancelledError:
                    # Don't leak the half-open socket on task cancellation.
                    sock.close()
                    raise

                # Break explicitly a reference cycle
                err = None

                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )

                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )

                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                # OverflowError (e.g. out-of-range port) won't improve with
                # the next record: stop retrying.
                if isinstance(_, OverflowError):
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
|
||||
|
||||
|
||||
class AsyncManyResolver(AsyncBaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: AsyncBaseResolver) -> None:
        super().__init__(None, None)

        # Total child resolver count, kept for error reporting.
        self._size = len(resolvers)

        # Split children: constrained resolvers only answer for specific
        # host patterns and are consulted first in getaddrinfo().
        self._unconstrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[AsyncBaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]

        # Number of in-flight __resolvers() iterations, used to rotate the
        # starting child so concurrent lookups spread across children.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> AsyncBaseResolver:
        resolvers = []

        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())

        return AsyncManyResolver(*resolvers)

    async def close(self) -> None:  # type: ignore[override]
        for resolver in self._unconstrained + self._constrained:
            await resolver.close()

        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[AsyncBaseResolver, None, None]:
        """Yield children round-robin, recycling any that became unavailable.

        The starting index rotates with the number of concurrent callers so
        parallel lookups don't all hammer the first child.
        """
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        # NOTE(review): self._lock appears to come from BaseResolver — a
        # synchronous lock held only for the counter updates; confirm.
        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            start_idx = (self._concurrent - 1) % resolver_count

            # First pass: from the rotated start index to the end.
            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            # Second pass: wrap around to cover the children before start_idx.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve via children: constrained resolvers first, then the rest.

        If any constrained resolver claims the host, the unconstrained pool is
        never consulted for it. DNSSEC/DNSKEY failures propagate immediately;
        other gaierrors fall through to the next child.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        tested_resolvers = []

        any_constrained_tried: bool = False

        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True

                try:
                    results = await resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )

                    if results:
                        return results
                except socket.gaierror as exc:
                    # Security-relevant failures must not be masked by fallback.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)

        # A constrained resolver owned this host but failed: do not fall back.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        for resolver in self.__resolvers():
            try:
                results = await resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )

                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import SystemResolver
|
||||
|
||||
__all__ = ("SystemResolver",)
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ...protocols import ProtocolResolver
|
||||
from ..protocols import AsyncBaseResolver
|
||||
|
||||
|
||||
class SystemResolver(AsyncBaseResolver):
    """Resolver delegating to the operating system via the running event loop.

    A thin async wrapper around ``loop.getaddrinfo``; carries no state of
    its own, so it can be recycled and "closed" freely.
    """

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The OS is the endpoint here; server/port kwargs are meaningless.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """The system resolver always handles None and localhost itself."""
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> AsyncBaseResolver:
        return self

    async def close(self) -> None:  # type: ignore[override]
        pass  # no-op!

    def is_available(self) -> bool:
        return True

    async def getaddrinfo(  # type: ignore[override]
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        # quic_upgrade_via_dns_rr is accepted for interface parity but the
        # OS resolver cannot serve HTTPS resource records.
        loop = asyncio.get_running_loop()
        return await loop.getaddrinfo(
            host=host,
            port=port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
        )
|
||||
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._urllib3 import (
|
||||
AdGuardResolver,
|
||||
CloudflareResolver,
|
||||
GoogleResolver,
|
||||
HTTPSResolver,
|
||||
NextDNSResolver,
|
||||
OpenDNSResolver,
|
||||
Quad9Resolver,
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"HTTPSResolver",
|
||||
"GoogleResolver",
|
||||
"CloudflareResolver",
|
||||
"AdGuardResolver",
|
||||
"OpenDNSResolver",
|
||||
"Quad9Resolver",
|
||||
"NextDNSResolver",
|
||||
)
|
||||
@@ -0,0 +1,641 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from base64 import b64encode
|
||||
|
||||
from ...._collections import HTTPHeaderDict
|
||||
from ....backend import ConnectionInfo, HttpVersion, ResponsePromise
|
||||
from ....connectionpool import HTTPSConnectionPool
|
||||
from ....response import HTTPResponse
|
||||
from ....util.url import parse_url
|
||||
from ..protocols import (
|
||||
BaseResolver,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..utils import is_ipv4, is_ipv6, validate_length_of, parse_https_rdata
|
||||
|
||||
|
||||
class HTTPSResolver(BaseResolver):
    """
    Advanced DNS over HTTPS resolver.
    No common ground emerged from IETF w/ JSON. Following Google's DNS over HTTPS schematics that is
    also implemented at Cloudflare.

    Support RFC 8484 without JSON. Disabled by default.
    """

    implementation = "urllib3"
    protocol = ProtocolResolver.DOH

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # DoH defaults to HTTPS port 443.
        super().__init__(server, port or 443, *patterns, **kwargs)

        # Query endpoint path; "/resolve" matches the Google JSON API default.
        self._path: str = "/resolve"

        if "path" in kwargs:
            if isinstance(kwargs["path"], str) and kwargs["path"] != "/":
                self._path = kwargs["path"]
            kwargs.pop("path")

        # When True, use the RFC 8484 binary wire format instead of JSON.
        self._rfc8484: bool = False

        if "rfc8484" in kwargs:
            if kwargs["rfc8484"]:
                self._rfc8484 = True
            kwargs.pop("rfc8484")

        assert self._server is not None

        # source_address may arrive as "ip:port" string; normalize to a tuple.
        if "source_address" in kwargs:
            if isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)

                if bind_ip and bind_port.isdigit():
                    kwargs["source_address"] = (
                        bind_ip,
                        int(bind_port),
                    )
                else:
                    raise ValueError("invalid source_address given in parameters")
            else:
                raise ValueError("invalid source_address given in parameters")

        if "proxy" in kwargs:
            kwargs["_proxy"] = parse_url(kwargs["proxy"])
            kwargs.pop("proxy")

        if "maxsize" not in kwargs:
            kwargs["maxsize"] = 10

        # Proxy headers arrive as "Name: value" strings; fold into a dict.
        if "proxy_headers" in kwargs and "_proxy" in kwargs:
            proxy_headers = HTTPHeaderDict()

            if not isinstance(kwargs["proxy_headers"], list):
                kwargs["proxy_headers"] = [kwargs["proxy_headers"]]

            for item in kwargs["proxy_headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")

                k, v = item.split(":", 1)
                proxy_headers.add(k, v)

            kwargs["_proxy_headers"] = proxy_headers

        # Same normalization for plain request headers.
        if "headers" in kwargs:
            headers = HTTPHeaderDict()

            if not isinstance(kwargs["headers"], list):
                kwargs["headers"] = [kwargs["headers"]]

            for item in kwargs["headers"]:
                if ":" not in item:
                    raise ValueError("Passed header is invalid in DNS parameters")

                k, v = item.split(":", 1)
                headers.add(k, v)

            kwargs["headers"] = headers

        # Map string protocol names ("h11"/"h2"/"h3") to HttpVersion members.
        if "disabled_svn" in kwargs:
            if not isinstance(kwargs["disabled_svn"], list):
                kwargs["disabled_svn"] = [kwargs["disabled_svn"]]

            disabled_svn = set()

            for svn in kwargs["disabled_svn"]:
                svn = svn.lower()

                if svn == "h11":
                    disabled_svn.add(HttpVersion.h11)
                elif svn == "h2":
                    disabled_svn.add(HttpVersion.h2)
                elif svn == "h3":
                    disabled_svn.add(HttpVersion.h3)

            kwargs["disabled_svn"] = disabled_svn

        # Optional callback invoked once the pool establishes a connection.
        if "on_post_connection" in kwargs and callable(kwargs["on_post_connection"]):
            self._connection_callback: (
                typing.Callable[[ConnectionInfo], None] | None
            ) = kwargs["on_post_connection"]
            kwargs.pop("on_post_connection")
        else:
            self._connection_callback = None

        self._pool = HTTPSConnectionPool(self._server, self._port, **kwargs)

    def close(self) -> None:
        self._pool.close()

    def is_available(self) -> bool:
        return self._pool.pool is not None

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve ``host`` over HTTPS (multiplexed A/AAAA, optionally HTTPS RR).

        With ``quic_upgrade_via_dns_rr`` the HTTPS (type 65) record is also
        queried; when it advertises h3 support, UDP datagram results are
        synthesized from the TCP records so callers may upgrade to QUIC.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a HTTPSResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # IP literals never hit the wire.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror("Address family for hostname not supported")
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        # Queries are issued multiplexed; promises are collected then awaited.
        promises: list[HTTPResponse | ResponsePromise] = []
        remote_preemptive_quic_rr = False

        # Caller already asked for datagram sockets: no upgrade to suggest.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # A record (type 1) query — JSON or RFC 8484 wire format.
        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "1"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.A, override_id=0
                )
                dns_payload = bytes(dns_query)
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            # RFC 8484 requires unpadded base64url; padding stripped.
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # AAAA record (type 28) query.
        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "28"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.AAAA, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        # HTTPS record (type 65) query, only when the QUIC upgrade is wanted.
        if quic_upgrade_via_dns_rr:
            if not self._rfc8484:
                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {"name": host, "type": "65"},
                        headers={"Accept": "application/dns-json"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )
            else:
                dns_query = DomainNameServerQuery(
                    host, SupportedQueryType.HTTPS, override_id=0
                )
                dns_payload = bytes(dns_query)

                promises.append(
                    self._pool.request_encode_url(
                        "GET",
                        self._path,
                        {
                            "dns": b64encode(dns_payload).decode().replace("=", ""),
                        },
                        headers={"Accept": "application/dns-message"},
                        on_post_connection=self._connection_callback,
                        multiplexed=True,
                    )
                )

        responses: list[HTTPResponse] = []

        # The pool may have answered eagerly (HTTPResponse) or lazily (promise).
        for promise in promises:
            if isinstance(promise, HTTPResponse):
                responses.append(promise)
                continue
            responses.append(self._pool.get_response(promise=promise))  # type: ignore[arg-type]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        for response in responses:
            if response.status >= 300:
                raise socket.gaierror(
                    f"DNS over HTTPS was unsuccessful, server response status {response.status}."
                )

            if not self._rfc8484:
                # Google/Cloudflare JSON schema path.
                payload = response.json()

                assert "Status" in payload and isinstance(payload["Status"], int)

                # Non-zero Status is a DNS RCODE failure reported by the server.
                if payload["Status"] != 0:
                    msg = (
                        payload["Comment"]
                        if "Comment" in payload
                        else f"Remote DNS indicated that an error occurred while providing resolution. Status {payload['Status']}."
                    )

                    if isinstance(msg, list):
                        msg = ", ".join(msg)

                    raise socket.gaierror(msg)

                assert "Question" in payload and isinstance(payload["Question"], list)

                # No Answer section means no record of that type; not an error.
                if "Answer" not in payload:
                    continue

                assert isinstance(payload["Answer"], list)

                for answer in payload["Answer"]:
                    if answer["type"] not in [1, 28, 65]:
                        continue

                    assert "data" in answer
                    assert isinstance(answer["data"], str)

                    # DNS RR/HTTPS
                    if answer["type"] == 65:
                        # "1 . alpn=h3,h2 ipv4hint=104.16.132.229,104.16.133.229 ipv6hint=2606:4700::6810:84e5,2606:4700::6810:85e5"
                        # or..
                        # "1 . alpn=h2,h3"
                        rr: str = answer["data"]

                        if rr.startswith("\\#"):  # it means, raw, bytes.
                            rr = "".join(rr[2:].split(" ")[2:])

                            try:
                                raw_record = bytes.fromhex(rr)
                            except ValueError:
                                raw_record = b""

                            https_record = parse_https_rdata(raw_record)

                            if "h3" not in https_record["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True
                        else:
                            # Presentation format: split "key=value" tokens.
                            rr_decode: dict[str, str] = dict(
                                tuple(_.lower().split("=", 1))  # type: ignore[misc]
                                for _ in rr.split(" ")
                                if "=" in _
                            )

                            if "alpn" not in rr_decode or "h3" not in rr_decode["alpn"]:
                                continue

                            remote_preemptive_quic_rr = True

                            # Address hints from the HTTPS RR become UDP records.
                            if "ipv4hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET,
                            ]:
                                for ipv4 in rr_decode["ipv4hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv4,
                                                port,
                                            ),
                                        )
                                    )
                            if "ipv6hint" in rr_decode and family in [
                                socket.AF_UNSPEC,
                                socket.AF_INET6,
                            ]:
                                for ipv6 in rr_decode["ipv6hint"].split(","):
                                    results.append(
                                        (
                                            socket.AF_INET6,
                                            socket.SOCK_DGRAM,
                                            17,
                                            "",
                                            (
                                                ipv6,
                                                port,
                                                0,
                                                0,
                                            ),
                                        )
                                    )

                        continue

                    inet_type = (
                        socket.AF_INET if answer["type"] == 1 else socket.AF_INET6
                    )

                    dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                        (
                            answer["data"],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            answer["data"],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )
            else:
                # RFC 8484 binary wire format path.
                dns_resp = DomainNameServerReturn(response.data)

                for record in dns_resp.records:
                    if record[0] == SupportedQueryType.HTTPS:
                        assert isinstance(record[-1], dict)
                        if "h3" in record[-1]["alpn"]:
                            remote_preemptive_quic_rr = True
                        continue

                    assert not isinstance(record[-1], dict)

                    inet_type = (
                        socket.AF_INET
                        if record[0] == SupportedQueryType.A
                        else socket.AF_INET6
                    )
                    dst_addr = (
                        (
                            record[-1],
                            port,
                        )
                        if inet_type == socket.AF_INET
                        else (
                            record[-1],
                            port,
                            0,
                            0,
                        )
                    )

                    results.append(
                        (
                            inet_type,
                            type,
                            6 if type == socket.SOCK_STREAM else 17,
                            "",
                            dst_addr,
                        )
                    )

        quic_results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        if remote_preemptive_quic_rr:
            any_specified = False

            # Mirror TCP records as UDP candidates, unless the RR already
            # supplied explicit UDP address hints.
            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        # Sort so higher family+socktype values (UDP/IPv6 candidates) come first.
        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class GoogleResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut preconfigured for Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream host is fixed; silently drop any caller-supplied server.
        kwargs.pop("server", None)
        # An explicit port override is honored; otherwise the parent default applies.
        port = kwargs.pop("port", None)
        # Opting into the RFC 8484 wire format routes queries to /dns-query.
        if kwargs.get("rfc8484"):
            kwargs["path"] = "/dns-query"
        super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut preconfigured for Cloudflare (cloudflare-dns.com)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Endpoint host is hard-wired; discard any caller-provided server.
        kwargs.pop("server", None)
        # Keep an explicit port override if one was given.
        port = kwargs.pop("port", None)
        # Cloudflare serves DoH on the /dns-query endpoint.
        kwargs["path"] = "/dns-query"
        super().__init__("cloudflare-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut for AdGuard's unfiltered public endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Server is fixed for this vendor shortcut; ignore overrides.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # AdGuard's DoH endpoint speaks RFC 8484 on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut preconfigured for OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Drop any caller-supplied server; the host is part of the shortcut.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # OpenDNS expects RFC 8484 formatted queries on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut preconfigured for Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Host is fixed by the shortcut; a server override is ignored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        # Quad9 serves RFC 8484 DoH on /dns-query.
        kwargs["path"] = "/dns-query"
        kwargs["rfc8484"] = True
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
    HTTPSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-HTTPS shortcut preconfigured for NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The host is implied by the shortcut; discard a server override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,15 @@
|
||||
"""DNS-over-QUIC public exports.

DoQ support is optional: it requires the ``qh3`` package and a non-FIPS
``ssl`` build (the ``_qh3`` module raises ImportError at import time in
either failure case). When unavailable, the resolver names are exported
as ``None`` so callers can feature-detect instead of crashing on import.
"""
from __future__ import annotations

try:
    from ._qh3 import AdGuardResolver, NextDNSResolver, QUICResolver
except ImportError:
    # qh3 missing or FIPS-restricted ssl — DoQ is disabled, not fatal.
    QUICResolver = None  # type: ignore
    AdGuardResolver = None  # type: ignore
    NextDNSResolver = None  # type: ignore


__all__ = (
    "QUICResolver",
    "AdGuardResolver",
    "NextDNSResolver",
)
|
||||
@@ -0,0 +1,541 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import typing
|
||||
from collections import deque
|
||||
from ssl import SSLError
|
||||
from time import time as monotonic
|
||||
|
||||
from qh3.quic.configuration import QuicConfiguration
|
||||
from qh3.quic.connection import QuicConnection
|
||||
from qh3.quic.events import (
|
||||
ConnectionTerminated,
|
||||
HandshakeCompleted,
|
||||
QuicEvent,
|
||||
StopSendingReceived,
|
||||
StreamDataReceived,
|
||||
StreamReset,
|
||||
)
|
||||
|
||||
from ....util.ssl_ import IS_FIPS, resolve_cert_reqs
|
||||
from ...ssa._gro import _sock_has_gro, _sock_has_gso, sync_recv_gro, sync_sendmsg_gso
|
||||
from ..dou import PlainResolver
|
||||
from ..protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
|
||||
# Failing at import time lets the package-level try/except (in __init__)
# degrade gracefully by exporting the DoQ resolvers as None instead of
# exposing a resolver that cannot work under a FIPS-restricted ssl build.
if IS_FIPS:
    raise ImportError(
        "DNS-over-QUIC disabled when Python is built with FIPS-compliant ssl module"
    )
|
||||
|
||||
|
||||
class QUICResolver(PlainResolver):
|
||||
protocol = ProtocolResolver.DOQ
|
||||
implementation = "qh3"
|
||||
|
||||
    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ):
        """Open a DNS-over-QUIC session to *server* (default port 853).

        Performs the QUIC handshake eagerly: the constructor connects,
        exchanges datagrams until HandshakeCompleted, and leaves the
        session ready for queries. Raises ssl.SSLError when certificate
        verification is required but no CA could be loaded.
        """
        # Parent sets up self._server/self._port/self._socket (UDP) first;
        # the rest of this constructor depends on those attributes.
        super().__init__(server, port or 853, *patterns, **kwargs)

        # qh3 load_default_certs seems off. need to investigate.
        # Workaround: harvest the system trust store through the stdlib
        # SSLContext and hand it to qh3 as concatenated PEM data.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            kwargs["ca_cert_data"] = []

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

            try:
                ctx.load_default_certs()

                for der in ctx.get_ca_certs(binary_form=True):
                    kwargs["ca_cert_data"].append(ssl.DER_cert_to_PEM_cert(der))

                if kwargs["ca_cert_data"]:
                    kwargs["ca_cert_data"] = "".join(kwargs["ca_cert_data"])
                else:
                    # Nothing harvested: remove the key so the check below fires.
                    del kwargs["ca_cert_data"]
            except (AttributeError, ValueError, OSError):
                # Best-effort harvest failed; fall through to the hard check.
                del kwargs["ca_cert_data"]

        # With no CA material at all, verification (the default) cannot
        # succeed — fail loudly unless the caller explicitly disabled it.
        if "ca_cert_data" not in kwargs and "ca_certs" not in kwargs:
            if (
                "cert_reqs" not in kwargs
                or resolve_cert_reqs(kwargs["cert_reqs"]) is ssl.CERT_REQUIRED
            ):
                raise ssl.SSLError(
                    "DoQ requires at least one CA loaded in order to verify the remote peer certificate. "
                    "Add ?cert_reqs=0 to disable certificate checks."
                )

        configuration = QuicConfiguration(
            is_client=True,
            # "doq" is the ALPN token registered for DNS over QUIC (RFC 9250).
            alpn_protocols=["doq"],
            server_name=self._server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            verify_mode=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else ssl.CERT_REQUIRED,
            cadata=kwargs["ca_cert_data"].encode()
            if "ca_cert_data" in kwargs
            else None,
            cafile=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            idle_timeout=300.0,
        )

        # Optional mutual-TLS client certificate, from file or in-memory data.
        if "cert_file" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_file"],
                kwargs["key_file"] if "key_file" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )
        elif "cert_data" in kwargs:
            configuration.load_cert_chain(
                kwargs["cert_data"],
                kwargs["key_data"] if "key_data" in kwargs else None,
                kwargs["key_password"] if "key_password" in kwargs else None,
            )

        self._quic = QuicConnection(configuration=configuration)

        # Kernel UDP offload capabilities (generic receive/segmentation offload).
        self._dgram_gro_enabled: bool = _sock_has_gro(self._socket)
        self._dgram_gso_enabled: bool = _sock_has_gso(self._socket)

        # Eager handshake: block until the server confirms.
        self._quic.connect((self._server, self._port), monotonic())
        self.__exchange_until(HandshakeCompleted, receive_first=False)

        self._terminated: bool = False
        self._should_disconnect: bool = False

        # DNS over QUIC mandate the size-prefix (unsigned int, 2b)
        self._rfc1035_prefix_mandated = True

        # Cross-thread bookkeeping: responses read by another thread land in
        # _unconsumed; queries awaiting an answer sit in _pending.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()
|
||||
|
||||
    def close(self) -> None:
        """Gracefully terminate the QUIC session and close the UDP socket.

        Idempotent: subsequent calls are no-ops once ``_terminated`` is set.
        The CONNECTION_CLOSE frames produced by qh3 are flushed to the wire
        before the socket is closed.
        """
        if not self._terminated:
            with self._lock:
                self._quic.close()

                # Drain every pending outbound datagram (close frames included).
                while True:
                    datagrams = self._quic.datagrams_to_send(monotonic())

                    if not datagrams:
                        break

                    # Batch via GSO when the kernel supports it; otherwise send one by one.
                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])

                self._socket.close()
                self._terminated = True
|
||||
|
||||
    def is_available(self) -> bool:
        """Return True while the QUIC session can still serve queries.

        Pumps qh3's timer first so idle timeouts are observed, then probes
        qh3's private ``_close_event`` (guarded by hasattr because it is not
        public API and may change between qh3 versions).
        """
        self._quic.handle_timer(monotonic())
        if hasattr(self._quic, "_close_event") and self._quic._close_event is not None:
            self._terminated = True
        return not self._terminated
|
||||
|
||||
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* over the established DoQ session.

        Mirrors ``socket.getaddrinfo`` semantics: literal IPv4/IPv6 hosts
        short-circuit without a network round-trip; otherwise A/AAAA (and
        optionally HTTPS RR) queries are multiplexed on separate QUIC
        streams. Raises socket.gaierror on any resolution failure.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' using the QUICResolver"
            )
        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass DNS entirely, as the stdlib does.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        remote_preemptive_quic_rr = False

        # HTTPS-RR upgrade probing only makes sense for stream sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Build the list of record types to query.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)
        open_streams = []

        # One QUIC stream per query (RFC 9250), all flushed in one batch.
        with self._lock:
            for q in queries:
                payload = bytes(q)

                self._pending.append(q)

                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)

                stream_id = self._quic.get_next_available_stream_id()
                self._quic.send_stream_data(stream_id, payload, True)

                open_streams.append(stream_id)

            datagrams = self._quic.datagrams_to_send(monotonic())
            if self._dgram_gso_enabled and len(datagrams) > 1:
                sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
            else:
                for dg in datagrams:
                    self._socket.sendall(dg[0])

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            with self._lock:
                # Another thread sharing this resolver may have already read
                # one of our responses; claim it from the shared backlog.
                if self._unconsumed:
                    dns_resp = None
                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break
                    if dns_resp:
                        self._unconsumed.remove(dns_resp)
                        self._pending.remove(query)
                        continue

                try:
                    events: list[StreamDataReceived] = self.__exchange_until(  # type: ignore[assignment]
                        StreamDataReceived,
                        receive_first=True,
                        event_type_collectable=(StreamDataReceived,),
                        respect_end_stream_signal=False,
                    )

                    payload = b"".join([e.data for e in events])

                    # The 2-byte length prefix says more data is due: keep reading.
                    while rfc1035_should_read(payload):
                        events.extend(
                            self.__exchange_until(  # type: ignore[arg-type]
                                StreamDataReceived,
                                receive_first=True,
                                event_type_collectable=(StreamDataReceived,),
                                respect_end_stream_signal=False,
                            )
                        )
                        payload = b"".join([e.data for e in events])
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e

                if not payload:
                    continue

                #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                fragments = rfc1035_unpack(payload)

                for fragment in fragments:
                    dns_resp = DomainNameServerReturn(fragment)

                    if any(dns_resp.id == _.id for _ in queries):
                        responses.append(dns_resp)

                        query_tbr: DomainNameServerQuery | None = None

                        # NOTE(review): if no pending entry matches, query_tbr
                        # keeps the last iterated value and the `if query_tbr`
                        # below removes an unrelated query — verify whether a
                        # duplicate response can reach this path.
                        for query_tbr in self._pending:
                            if query_tbr.id == dns_resp.id:
                                break
                        if query_tbr:
                            self._pending.remove(query_tbr)
                    else:
                        # Response belongs to a sibling thread's query.
                        self._unconsumed.append(dns_resp)

        if self._should_disconnect:
            with self._lock:
                # NOTE(review): close() also acquires self._lock — this is only
                # safe if _lock is reentrant; confirm in BaseResolver.
                self.close()
                self._should_disconnect = False
                self._terminated = True

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs only signal h3 support; they carry no address here.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        # When the server advertised h3, duplicate each TCP candidate as a UDP
        # (QUIC) candidate — unless the caller already asked for datagrams.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
    def __exchange_until(
        self,
        event_type: type[QuicEvent] | tuple[type[QuicEvent], ...],
        *,
        receive_first: bool = False,
        event_type_collectable: type[QuicEvent]
        | tuple[type[QuicEvent], ...]
        | None = None,
        respect_end_stream_signal: bool = True,
    ) -> list[QuicEvent]:
        """Pump the QUIC engine until an event of *event_type* arrives.

        Optionally flushes outbound datagrams first (``receive_first=False``),
        then alternates socket reads / qh3 event drains. Collected events are
        filtered by *event_type_collectable* when given. Raises SSLError on a
        TLS verification close (qh3 error code 298) and socket.gaierror on any
        other connection termination or stream reset.
        """
        while True:
            if receive_first is False:
                now = monotonic()
                # Flush everything qh3 has queued before we start reading.
                while True:
                    datagrams = self._quic.datagrams_to_send(now)

                    if not datagrams:
                        break

                    if self._dgram_gso_enabled and len(datagrams) > 1:
                        sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                    else:
                        for datagram in datagrams:
                            self._socket.sendall(datagram[0])

            events = []

            while True:
                # Only hit the socket when qh3 has no queued events left.
                if not self._quic._events:
                    if self._dgram_gro_enabled:
                        data_in = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in = self._socket.recv(1500)

                    if not data_in:
                        break

                    now = monotonic()

                    # GRO delivers a list of coalesced segments; feed each one.
                    if isinstance(data_in, list):
                        for gro_segment in data_in:
                            self._quic.receive_datagram(
                                gro_segment, (self._server, self._port), now
                            )
                    else:
                        self._quic.receive_datagram(
                            data_in, (self._server, self._port), now
                        )

                    # Feeding input may generate ACKs etc. — flush them now.
                    while True:
                        now = monotonic()
                        datagrams = self._quic.datagrams_to_send(now)

                        if not datagrams:
                            break

                        if self._dgram_gso_enabled and len(datagrams) > 1:
                            sync_sendmsg_gso(self._socket, [d[0] for d in datagrams])
                        else:
                            for datagram in datagrams:
                                self._socket.sendall(datagram[0])

                for ev in iter(self._quic.next_event, None):
                    if isinstance(ev, ConnectionTerminated):
                        # 298 is qh3's close code for certificate verification failure.
                        if ev.error_code == 298:
                            raise SSLError(
                                "DNS over QUIC did not succeed (Error 298). Chain certificate verification failed."
                            )
                        raise socket.gaierror(
                            f"DNS over QUIC encountered a unrecoverable failure (error {ev.error_code} {ev.reason_phrase})"
                        )
                    elif isinstance(ev, StreamReset):
                        self._terminated = True
                        raise socket.gaierror(
                            "DNS over QUIC server submitted a StreamReset. A request was rejected."
                        )
                    elif isinstance(ev, StopSendingReceived):
                        # Server asks us to stop sending: disconnect lazily later.
                        self._should_disconnect = True
                        continue

                    if event_type_collectable:
                        if isinstance(ev, event_type_collectable):
                            events.append(ev)
                    else:
                        events.append(ev)

                    if isinstance(ev, event_type):
                        if not respect_end_stream_signal:
                            return events
                        # Honor end-of-stream only for events that carry the flag.
                        if hasattr(ev, "stream_ended") and ev.stream_ended:
                            return events
                        elif hasattr(ev, "stream_ended") is False:
                            return events

            return events
|
||||
|
||||
|
||||
class AdGuardResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC shortcut for AdGuard's unfiltered public endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # Host is part of the shortcut; a caller-supplied server is discarded.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class NextDNSResolver(
    QUICResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-QUIC shortcut preconfigured for NextDNS (dns.nextdns.io)."""

    specifier = "nextdns"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # Fixed upstream host; ignore any server override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.nextdns.io", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""DNS-over-TLS (RFC 7858) public exports.

Re-exports the ssl-backed TLSResolver and the preconfigured vendor
shortcuts so callers import from the package, not the private module.
"""
from __future__ import annotations

from ._ssl import (
    AdGuardResolver,
    CloudflareResolver,
    GoogleResolver,
    OpenDNSResolver,
    Quad9Resolver,
    TLSResolver,
)

__all__ = (
    "TLSResolver",
    "GoogleResolver",
    "CloudflareResolver",
    "AdGuardResolver",
    "Quad9Resolver",
    "OpenDNSResolver",
)
|
||||
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ....util.ssl_ import resolve_cert_reqs, ssl_wrap_socket
|
||||
from ..dou import PlainResolver
|
||||
from ..protocols import ProtocolResolver
|
||||
from ..system import SystemResolver
|
||||
|
||||
|
||||
class TLSResolver(PlainResolver):
    """
    Basic DNS resolver over TLS.
    Comply with RFC 7858: https://datatracker.ietf.org/doc/html/rfc7858
    """

    protocol = ProtocolResolver.DOT
    implementation = "ssl"

    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # Non-numeric/absent timeout means "no timeout" on the socket.
        if "timeout" in kwargs and isinstance(kwargs["timeout"], (int, float)):
            timeout = kwargs["timeout"]
        else:
            timeout = None

        # Optional "ip:port" local bind address; "0.0.0.0:0" means unbound.
        if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
            bind_ip, bind_port = kwargs["source_address"].split(":", 1)
        else:
            bind_ip, bind_port = "0.0.0.0", "0"

        # Build the TCP socket BEFORE calling super().__init__: PlainResolver
        # only creates its own (UDP) socket when self._socket is absent.
        self._socket = SystemResolver().create_connection(
            (server, port or 853),
            timeout=timeout,
            source_address=(bind_ip, int(bind_port))
            if bind_ip != "0.0.0.0" or bind_port != "0"
            else None,
            socket_options=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1, "tcp"),),
            socket_kind=socket.SOCK_STREAM,
        )

        super().__init__(server, port, *patterns, **kwargs)

        # Upgrade the plain TCP socket to TLS; every keyword is optional and
        # forwarded verbatim from the resolver's configuration kwargs.
        self._socket = ssl_wrap_socket(
            self._socket,
            server_hostname=server
            if "server_hostname" not in kwargs
            else kwargs["server_hostname"],
            keyfile=kwargs["key_file"] if "key_file" in kwargs else None,
            certfile=kwargs["cert_file"] if "cert_file" in kwargs else None,
            cert_reqs=resolve_cert_reqs(kwargs["cert_reqs"])
            if "cert_reqs" in kwargs
            else None,
            ca_certs=kwargs["ca_certs"] if "ca_certs" in kwargs else None,
            ssl_version=kwargs["ssl_version"] if "ssl_version" in kwargs else None,
            ciphers=kwargs["ciphers"] if "ciphers" in kwargs else None,
            ca_cert_dir=kwargs["ca_cert_dir"] if "ca_cert_dir" in kwargs else None,
            key_password=kwargs["key_password"] if "key_password" in kwargs else None,
            ca_cert_data=kwargs["ca_cert_data"] if "ca_cert_data" in kwargs else None,
            certdata=kwargs["cert_data"] if "cert_data" in kwargs else None,
            keydata=kwargs["key_data"] if "key_data" in kwargs else None,
        )

        # DNS over TLS mandate the size-prefix (unsigned int, 2 bytes)
        self._rfc1035_prefix_mandated = True
|
||||
|
||||
|
||||
class GoogleResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut preconfigured for Google Public DNS (dns.google)."""

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Upstream host is fixed; discard any caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.google", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class CloudflareResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut preconfigured for Cloudflare (1.1.1.1)."""

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed endpoint; server overrides are ignored.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut for AdGuard's unfiltered public endpoint."""

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The host belongs to the shortcut; drop any server override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("unfiltered.adguard-dns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class OpenDNSResolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut preconfigured for OpenDNS (dns.opendns.com)."""

    specifier = "opendns"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # Fixed upstream host; ignore a caller-supplied server.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns.opendns.com", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(
    TLSResolver
):  # Defensive: we do not cover specific vendors/DNS shortcut
    """DNS-over-TLS shortcut preconfigured for Quad9 (dns11.quad9.net)."""

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The host is implied by the shortcut; discard any override.
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)
        super().__init__("dns11.quad9.net", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Plain DNS-over-UDP (RFC 1035) public exports.

Re-exports the socket-backed PlainResolver and the preconfigured vendor
shortcuts so callers import from the package, not the private module.
"""
from __future__ import annotations

from ._socket import (
    AdGuardResolver,
    CloudflareResolver,
    GoogleResolver,
    PlainResolver,
    Quad9Resolver,
)

__all__ = (
    "PlainResolver",
    "CloudflareResolver",
    "GoogleResolver",
    "Quad9Resolver",
    "AdGuardResolver",
)
|
||||
@@ -0,0 +1,415 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
from collections import deque
|
||||
|
||||
from ...ssa._gro import _sock_has_gro, sync_recv_gro
|
||||
from ..protocols import (
|
||||
COMMON_RCODE_LABEL,
|
||||
BaseResolver,
|
||||
DomainNameServerQuery,
|
||||
DomainNameServerReturn,
|
||||
ProtocolResolver,
|
||||
SupportedQueryType,
|
||||
)
|
||||
from ..system import SystemResolver
|
||||
from ..utils import (
|
||||
is_ipv4,
|
||||
is_ipv6,
|
||||
packet_fragment,
|
||||
rfc1035_pack,
|
||||
rfc1035_should_read,
|
||||
rfc1035_unpack,
|
||||
validate_length_of,
|
||||
)
|
||||
|
||||
|
||||
class PlainResolver(BaseResolver):
|
||||
"""
|
||||
Minimalist DNS resolver over UDP
|
||||
Comply with RFC 1035: https://datatracker.ietf.org/doc/html/rfc1035
|
||||
|
||||
EDNS is not supported, yet. But we plan to. Willing to contribute?
|
||||
"""
|
||||
|
||||
protocol = ProtocolResolver.DOU
|
||||
implementation = "socket"
|
||||
|
||||
    def __init__(
        self,
        server: str,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        """Connect a UDP socket to *server* (default port 53).

        Subclasses (e.g. TLSResolver) may pre-create ``self._socket`` before
        delegating here, in which case no UDP socket is opened.
        """
        super().__init__(server, port, *patterns, **kwargs)

        # A subclass may already have installed its own socket (TCP/TLS).
        if not hasattr(self, "_socket"):
            # Non-numeric/absent timeout means "no timeout".
            if "timeout" in kwargs and isinstance(
                kwargs["timeout"],
                (
                    float,
                    int,
                ),
            ):
                timeout = kwargs["timeout"]
            else:
                timeout = None

            # Optional "ip:port" local bind address; "0.0.0.0:0" means unbound.
            if "source_address" in kwargs and isinstance(kwargs["source_address"], str):
                bind_ip, bind_port = kwargs["source_address"].split(":", 1)
            else:
                bind_ip, bind_port = "0.0.0.0", "0"

            self._socket = SystemResolver().create_connection(
                (server, port or 53),
                timeout=timeout,
                source_address=(bind_ip, int(bind_port))
                if bind_ip != "0.0.0.0" or bind_port != "0"
                else None,
                socket_options=None,
                socket_kind=socket.SOCK_DGRAM,
            )

        #: Only useful for inheritance, e.g. DNS over TLS support dns-message but require a prefix.
        self._rfc1035_prefix_mandated: bool = False

        # Kernel generic-receive-offload capability for batched UDP reads.
        self._gro_enabled: bool = _sock_has_gro(self._socket)

        # Cross-thread bookkeeping: responses read on behalf of other threads
        # land in _unconsumed; queries awaiting an answer sit in _pending.
        self._unconsumed: deque[DomainNameServerReturn] = deque()
        self._pending: deque[DomainNameServerQuery] = deque()

        self._terminated: bool = False
|
||||
|
||||
    def close(self) -> None:
        """Shut down and close the underlying socket. Idempotent."""
        if not self._terminated:
            with self._lock:
                if self._socket is not None:
                    # NOTE(review): shutdown(0) on a UDP socket can raise
                    # OSError on some platforms — confirm this is intended
                    # to propagate rather than be suppressed.
                    self._socket.shutdown(0)
                    self._socket.close()
                self._terminated = True
|
||||
|
||||
    def is_available(self) -> bool:
        """Return True while the resolver has not been closed/terminated."""
        return not self._terminated
|
||||
|
||||
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve *host* using plain DNS over the connected socket.

        Mirrors ``socket.getaddrinfo`` semantics: literal IPv4/IPv6 hosts
        short-circuit without a network round-trip; otherwise A/AAAA (and
        optionally HTTPS RR) queries are sent and answers are matched back
        to their queries by transaction id. Raises socket.gaierror on any
        resolution failure.
        """
        if host is None:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Tried to resolve 'localhost' from a PlainResolver"
            )

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        # Literal addresses bypass DNS entirely, as the stdlib does.
        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        validate_length_of(host)

        remote_preemptive_quic_rr = False

        # HTTPS-RR upgrade probing only makes sense for stream sockets.
        if quic_upgrade_via_dns_rr and type == socket.SOCK_DGRAM:
            quic_upgrade_via_dns_rr = False

        # Build the list of record types to query.
        tbq = []

        if family in [socket.AF_UNSPEC, socket.AF_INET]:
            tbq.append(SupportedQueryType.A)

        if family in [socket.AF_UNSPEC, socket.AF_INET6]:
            tbq.append(SupportedQueryType.AAAA)

        if quic_upgrade_via_dns_rr:
            tbq.append(SupportedQueryType.HTTPS)

        queries = DomainNameServerQuery.bulk(host, *tbq)

        with self._lock:
            for q in queries:
                payload = bytes(q)
                self._pending.append(q)

                if self._rfc1035_prefix_mandated is True:
                    payload = rfc1035_pack(payload)

                self._socket.sendall(payload)

        responses: list[DomainNameServerReturn] = []

        while len(responses) < len(tbq):
            with self._lock:
                #: There we want to verify if another thread got a response that belong to this thread.
                if self._unconsumed:
                    dns_resp = None

                    for query in queries:
                        for unconsumed in self._unconsumed:
                            if unconsumed.id == query.id:
                                dns_resp = unconsumed
                                responses.append(dns_resp)
                                break
                        if dns_resp:
                            break

                    if dns_resp:
                        self._pending.remove(query)
                        self._unconsumed.remove(dns_resp)
                        continue

                try:
                    if self._gro_enabled:
                        data_in_or_segments = sync_recv_gro(self._socket, 65535)
                    else:
                        data_in_or_segments = self._socket.recv(1500)

                    # GRO yields a list of coalesced segments; normalize shape.
                    if isinstance(data_in_or_segments, list):
                        payloads = data_in_or_segments
                    elif data_in_or_segments:
                        payloads = [data_in_or_segments]
                    else:
                        payloads = []

                    # Size-prefixed transports (DoT) may split a message across
                    # reads: keep reading until the announced length is met.
                    if self._rfc1035_prefix_mandated is True and payloads:
                        payload = b"".join(payloads)
                        while rfc1035_should_read(payload):
                            extra = self._socket.recv(1500)
                            if isinstance(extra, list):
                                payload += b"".join(extra)
                            else:
                                payload += extra
                        payloads = [payload]
                except (TimeoutError, OSError, socket.timeout, ConnectionError) as e:
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    ) from e

                if not payloads:
                    self._terminated = True
                    raise socket.gaierror(
                        "Got unexpectedly disconnected while waiting for name resolution"
                    )

                pending_raw_identifiers = [_.raw_id for _ in self._pending]

                for payload in payloads:
                    #: We can receive two responses at once (or more, concatenated). Let's unwrap them.
                    if self._rfc1035_prefix_mandated is True:
                        fragments = rfc1035_unpack(payload)
                    else:
                        fragments = packet_fragment(payload, *pending_raw_identifiers)

                    for fragment in fragments:
                        dns_resp = DomainNameServerReturn(fragment)

                        if any(dns_resp.id == _.id for _ in queries):
                            responses.append(dns_resp)

                            query_tbr: DomainNameServerQuery | None = None

                            # NOTE(review): if no pending entry matches,
                            # query_tbr keeps the last iterated value and the
                            # `if query_tbr` below removes an unrelated query —
                            # verify whether a duplicate response can hit this.
                            for query_tbr in self._pending:
                                if query_tbr.id == dns_resp.id:
                                    break

                            if query_tbr:
                                self._pending.remove(query_tbr)
                        else:
                            # Response belongs to a sibling thread's query.
                            self._unconsumed.append(dns_resp)

        results = []

        for response in responses:
            if not response.is_ok:
                if response.rcode == 2:
                    raise socket.gaierror(
                        f"DNSSEC validation failure. Check http://dnsviz.net/d/{host}/dnssec/ and http://dnssec-debugger.verisignlabs.com/{host} for errors"
                    )
                raise socket.gaierror(
                    f"DNS returned an error: {COMMON_RCODE_LABEL[response.rcode] if response.rcode in COMMON_RCODE_LABEL else f'code {response.rcode}'}"
                )

            for record in response.records:
                # HTTPS RRs only signal h3 support; they carry no address here.
                if record[0] == SupportedQueryType.HTTPS:
                    assert isinstance(record[-1], dict)
                    if "h3" in record[-1]["alpn"]:
                        remote_preemptive_quic_rr = True
                    continue

                assert not isinstance(record[-1], dict)

                inet_type = (
                    socket.AF_INET
                    if record[0] == SupportedQueryType.A
                    else socket.AF_INET6
                )
                dst_addr: tuple[str, int] | tuple[str, int, int, int] = (
                    (
                        record[-1],
                        port,
                    )
                    if inet_type == socket.AF_INET
                    else (
                        record[-1],
                        port,
                        0,
                        0,
                    )
                )

                results.append(
                    (
                        inet_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        dst_addr,
                    )
                )

        quic_results = []

        # When the server advertised h3, duplicate each TCP candidate as a UDP
        # (QUIC) candidate — unless the caller already asked for datagrams.
        if remote_preemptive_quic_rr:
            any_specified = False

            for result in results:
                if result[1] == socket.SOCK_STREAM:
                    quic_results.append(
                        (result[0], socket.SOCK_DGRAM, 17, "", result[4])
                    )
                else:
                    any_specified = True
                    break

            if any_specified:
                quic_results = []

        return sorted(quic_results + results, key=lambda _: _[0] + _[1], reverse=True)
|
||||
|
||||
|
||||
class CloudflareResolver(PlainResolver):
    """Plain DNS resolver pinned to Cloudflare's public server (1.1.1.1).

    Defensive: we do not cover specific vendors/DNS shortcut.
    """

    specifier = "cloudflare"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream server is fixed for this vendor shortcut; silently
        # discard any caller-supplied "server" while still honoring "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("1.1.1.1", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class GoogleResolver(PlainResolver):
    """Plain DNS resolver pinned to Google's public server (8.8.8.8).

    Defensive: we do not cover specific vendors/DNS shortcut.
    """

    specifier = "google"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream server is fixed for this vendor shortcut; silently
        # discard any caller-supplied "server" while still honoring "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("8.8.8.8", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class Quad9Resolver(PlainResolver):
    """Plain DNS resolver pinned to Quad9's public server (9.9.9.9).

    Defensive: we do not cover specific vendors/DNS shortcut.
    """

    specifier = "quad9"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream server is fixed for this vendor shortcut; silently
        # discard any caller-supplied "server" while still honoring "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("9.9.9.9", port, *patterns, **kwargs)
|
||||
|
||||
|
||||
class AdGuardResolver(PlainResolver):
    """Plain DNS resolver pinned to AdGuard's public server (94.140.14.140).

    Defensive: we do not cover specific vendors/DNS shortcut.
    """

    specifier = "adguard"

    def __init__(self, *patterns: str, **kwargs: typing.Any) -> None:
        # The upstream server is fixed for this vendor shortcut; silently
        # discard any caller-supplied "server" while still honoring "port".
        kwargs.pop("server", None)
        port = kwargs.pop("port", None)

        super().__init__("94.140.14.140", port, *patterns, **kwargs)
|
||||
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from base64 import b64encode
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from ...util import parse_url
|
||||
from .protocols import BaseResolver, ProtocolResolver
|
||||
|
||||
|
||||
class ResolverFactory(metaclass=ABCMeta):
    """Locate and instantiate a concrete BaseResolver implementation at runtime."""

    @staticmethod
    def new(
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        **kwargs: Any,
    ) -> BaseResolver:
        """Import the module matching *protocol* (and optional *implementation*)
        and return an instance of the resolver class whose ``specifier`` matches.

        :raises NotImplementedError: when the module cannot be imported or no
            compatible class is found inside it.
        """
        # Resolver modules live under "<top-level package>.contrib.resolver".
        package_name: str = __name__.split(".")[0]

        # e.g. protocol "doh" -> ".doh", plus "._<impl>" when one is requested.
        module_expr = f".{protocol.value.replace('-', '_')}"
        if implementation:
            module_expr += f"._{implementation.replace('-', '_').lower()}"

        spe_msg = " " if specifier is None else f' (w/ specifier "{specifier}") '

        try:
            resolver_module = importlib.import_module(
                module_expr, f"{package_name}.contrib.resolver"
            )
        except ImportError as e:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. Tried to import '{module_expr}'. Did you specify a non-existent implementation?"
            ) from e

        def _eligible(candidate: typing.Any) -> bool:
            # Keep only BaseResolver subclasses whose specifier matches the
            # requested one (both None counts as a match).
            if not (
                isinstance(candidate, type) and issubclass(candidate, BaseResolver)
            ):
                return False
            if specifier is None and candidate.specifier is None:
                return True
            return specifier == candidate.specifier

        implementations: list[tuple[str, type[BaseResolver]]] = inspect.getmembers(
            resolver_module, _eligible
        )

        if not implementations:
            raise NotImplementedError(
                f"{protocol}{spe_msg}cannot be loaded. "
                "No compatible implementation available. "
                "Make sure your implementation inherit from BaseResolver."
            )

        # getmembers() yields (name, class) pairs sorted by name; mirror the
        # historical behavior of picking the last match.
        implementation_target: type[BaseResolver] = implementations[-1][1]

        return implementation_target(**kwargs)
|
||||
|
||||
|
||||
class ResolverDescription:
    """Describe how a BaseResolver must be instantiated."""

    def __init__(
        self,
        protocol: ProtocolResolver,
        specifier: str | None = None,
        implementation: str | None = None,
        server: str | None = None,
        port: int | None = None,
        *host_patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        # Which DNS transport to use (doh, dot, dou, system, ...).
        self.protocol = protocol
        # Optional vendor shortcut (e.g. "cloudflare", "google").
        self.specifier = specifier
        # Optional backend/implementation name.
        self.implementation = implementation
        self.server = server
        self.port = port
        # Hostname patterns the resolver should be constrained to.
        self.host_patterns = host_patterns
        # Extra keyword arguments forwarded to the resolver constructor.
        self.kwargs = kwargs

    def __setitem__(self, key: str, value: typing.Any) -> None:
        """Set/override a keyword argument forwarded to the resolver constructor."""
        self.kwargs[key] = value

    def __contains__(self, item: str) -> bool:
        """Return True when the given keyword argument is already set."""
        return item in self.kwargs

    def new(self) -> BaseResolver:
        """Instantiate the described resolver through :class:`ResolverFactory`."""
        kwargs = {**self.kwargs}

        if self.server:
            kwargs["server"] = self.server
        if self.port:
            kwargs["port"] = self.port
        if self.host_patterns:
            kwargs["patterns"] = self.host_patterns

        return ResolverFactory.new(
            self.protocol,
            self.specifier,
            self.implementation,
            **kwargs,
        )

    @staticmethod
    def from_url(url: str) -> ResolverDescription:
        """Parse a DNS url (e.g. ``doh+google://``) into a ResolverDescription.

        Query-string parameters become constructor keyword arguments; repeated
        or comma-separated parameters are folded into lists, and unambiguous
        scalar values are converted to bool/int/float.

        :raises ValueError: if the url has no scheme or carries several
            ``implementation`` parameters.
        """
        parsed_url = parse_url(url)

        schema = parsed_url.scheme

        if schema is None:
            raise ValueError("Given DNS url is missing a protocol")

        specifier = None
        implementation = None

        # "doh+cloudflare" -> protocol "doh", specifier "cloudflare"
        if "+" in schema:
            schema, specifier = tuple(schema.lower().split("+", 1))

        protocol = ProtocolResolver(schema)
        kwargs: dict[str, typing.Any] = {}

        if parsed_url.path:
            kwargs["path"] = parsed_url.path

        if parsed_url.auth:
            kwargs["headers"] = dict()
            if ":" in parsed_url.auth:
                username, password = parsed_url.auth.split(":")

                username = username.strip("'\"")
                password = password.strip("'\"")

                kwargs["headers"]["Authorization"] = (
                    f"Basic {b64encode(f'{username}:{password}'.encode()).decode()}"
                )
            else:
                # A lone credential is treated as a bearer token.
                kwargs["headers"]["Authorization"] = f"Bearer {parsed_url.auth}"

        if parsed_url.query:
            parameters = parse_qs(parsed_url.query)

            for parameter in parameters:
                if not parameters[parameter]:
                    continue

                parameter_insensible = parameter.lower()

                # Parameter given several times: fold every occurrence
                # (comma-split) into a single list.
                if (
                    isinstance(parameters[parameter], list)
                    and len(parameters[parameter]) > 1
                ):
                    if parameter == "implementation":
                        raise ValueError("Only one implementation can be passed to URL")

                    values = []

                    for e in parameters[parameter]:
                        if "," in e:
                            values.extend(e.split(","))
                        else:
                            values.append(e)

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(values)
                        else:
                            values.append(kwargs[parameter_insensible])
                            kwargs[parameter_insensible] = values
                        continue

                    kwargs[parameter_insensible] = values
                    continue

                value: str = parameters[parameter][0].lower().strip(" ")

                if parameter == "implementation":
                    implementation = value
                    continue

                # Single occurrence holding a comma separated list.
                if "," in value:
                    list_of_values = value.split(",")

                    if parameter_insensible in kwargs:
                        if isinstance(kwargs[parameter_insensible], list):
                            kwargs[parameter_insensible].extend(list_of_values)
                        else:
                            list_of_values.append(kwargs[parameter_insensible])
                            # bugfix: the merged list was previously discarded —
                            # it was never written back into kwargs (the
                            # multi-occurrence branch above does write it back).
                            kwargs[parameter_insensible] = list_of_values
                        continue

                    kwargs[parameter_insensible] = list_of_values
                    continue

                # Convert unambiguous scalars to bool, int or float; otherwise
                # keep the raw (lowercased) string.
                value_converted: bool | int | float | None = None

                if value in ["false", "true"]:
                    value_converted = True if value == "true" else False
                elif value.isdigit():
                    value_converted = int(value)
                elif (
                    value.count(".") == 1
                    and value.index(".") > 0
                    and value.replace(".", "").isdigit()
                ):
                    value_converted = float(value)

                kwargs[parameter_insensible] = (
                    value if value_converted is None else value_converted
                )

        host_patterns: list[str] = []

        # "hosts" is hoisted out of kwargs: it maps to positional patterns.
        if "hosts" in kwargs:
            host_patterns = (
                kwargs["hosts"].split(",")
                if isinstance(kwargs["hosts"], str)
                else kwargs["hosts"]
            )
            del kwargs["hosts"]

        return ResolverDescription(
            protocol,
            specifier,
            implementation,
            parsed_url.host,
            parsed_url.port,
            *host_patterns,
            **kwargs,
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._dict import InMemoryResolver
|
||||
|
||||
__all__ = ("InMemoryResolver",)
|
||||
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ....util.url import _IPV6_ADDRZ_RE
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
from ..utils import is_ipv4, is_ipv6
|
||||
|
||||
|
||||
class InMemoryResolver(BaseResolver):
    """Resolver backed by an in-process dict of hostname -> address records.

    Records come either from the ``hostname:addr`` patterns given at
    construction or from later :meth:`register` calls. No network I/O occurs.
    """

    protocol = ProtocolResolver.MANUAL
    implementation = "dict"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # This resolver has no upstream server; drop those kwargs if present.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)

        # Maximum number of distinct hostnames kept; oldest entry is evicted.
        self._maxsize = 65535 if "maxsize" not in kwargs else int(kwargs["maxsize"])
        self._hosts: dict[str, list[tuple[socket.AddressFamily, str]]] = {}

        # Patterns of the form "hostname:addr" seed the table, then the
        # pattern list is cleared so they are not treated as constraints.
        if self._host_patterns:
            for record in self._host_patterns:
                if ":" not in record:
                    continue
                hostname, addr = record.split(":", 1)
                self.register(hostname, addr)
            self._host_patterns = tuple([])

        # probably about our happy eyeballs impl (sync only)
        if len(self._hosts) == 1 and len(self._hosts[list(self._hosts.keys())[0]]) == 1:
            self._unsafe_expose = True

    def recycle(self) -> BaseResolver:
        """This resolver never closes; recycling returns the same instance."""
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def have_constraints(self) -> bool:
        # Always constrained: it can only answer for registered hostnames.
        return True

    def support(self, hostname: str | bytes | None) -> bool | None:
        """Return True only when the hostname has registered records."""
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return hostname in self._hosts

    def register(self, hostname: str, ipaddr: str) -> None:
        """Add an address record for *hostname*, deduplicating known addresses.

        Thread-safe. When the table exceeds ``maxsize`` the oldest hostname
        (first insertion order) is evicted.
        """
        with self._lock:
            if hostname not in self._hosts:
                self._hosts[hostname] = []
            else:
                for e in self._hosts[hostname]:
                    t, addr = e
                    # NOTE(review): substring containment, not equality —
                    # presumably to match bracketed IPv6 forms; confirm intent.
                    if addr in ipaddr:
                        return

            # Bracketed IPv6 literals are stored without their brackets.
            if _IPV6_ADDRZ_RE.match(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr[1:-1]))
            elif is_ipv6(ipaddr):
                self._hosts[hostname].append((socket.AF_INET6, ipaddr))
            else:
                self._hosts[hostname].append((socket.AF_INET, ipaddr))

            # Evict the first (oldest-inserted) hostname when over capacity.
            if len(self._hosts) > self._maxsize:
                k = None
                for k in self._hosts.keys():
                    break
                if k:
                    self._hosts.pop(k)

    def clear(self, hostname: str) -> None:
        """Drop every record registered for *hostname* (no-op if unknown)."""
        with self._lock:
            if hostname in self._hosts:
                del self._hosts[hostname]

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """socket.getaddrinfo()-compatible lookup against the in-memory table.

        IP literals are passed through without a table lookup.

        :raises socket.gaierror: for negative ports, family mismatches, or
            when no record matches the requested host/family.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(  # Defensive: stdlib cpy behavior
                "Servname not supported for ai_socktype"
            )

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            # NOTE(review): proto hard-coded to 6 (TCP) here regardless of
            # socket kind — mirrors the sibling NullResolver; confirm intent.
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(  # Defensive: stdlib cpy behavior
                    "Address family for hostname not supported"
                )
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        results: list[
            tuple[
                socket.AddressFamily,
                socket.SocketKind,
                int,
                str,
                tuple[str, int] | tuple[str, int, int, int],
            ]
        ] = []

        with self._lock:
            if host not in self._hosts:
                raise socket.gaierror(f"no records found for hostname {host} in-memory")

            for entry in self._hosts[host]:
                addr_type, addr_target = entry

                # Honor an explicit family filter; AF_UNSPEC matches all.
                if family != socket.AF_UNSPEC:
                    if family != addr_type:
                        continue

                results.append(
                    (
                        addr_type,
                        type,
                        6 if type == socket.SOCK_STREAM else 17,
                        "",
                        (addr_target, port)
                        if addr_type == socket.AF_INET
                        else (addr_target, port, 0, 0),
                    )
                )

        if not results:
            raise socket.gaierror(f"no records found for hostname {host} in-memory")

        return results
|
||||
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
from ..utils import is_ipv4, is_ipv6
|
||||
|
||||
|
||||
class NullResolver(BaseResolver):
    """Resolver that purposely refuses to resolve hostnames.

    IP literals are still passed through, so connections to explicit
    addresses keep working while any actual DNS resolution fails.
    """

    protocol = ProtocolResolver.NULL
    implementation = "dummy"

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # No upstream server makes sense here; drop those kwargs if present.
        if "server" in kwargs:
            kwargs.pop("server")
        if "port" in kwargs:
            kwargs.pop("port")
        super().__init__(None, None, *patterns, **kwargs)

    def recycle(self) -> BaseResolver:
        """Stateless resolver: recycling returns the same instance."""
        return self

    def close(self) -> None:
        pass  # no-op

    def is_available(self) -> bool:
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Pass through IP literals; raise for anything that needs resolution.

        :raises socket.gaierror: always, unless *host* is an IPv4/IPv6 literal.
        """
        if host is None:
            host = "localhost"  # Defensive: stdlib cpy behavior

        if port is None:
            port = 0  # Defensive: stdlib cpy behavior
        if isinstance(port, str):
            port = int(port)  # Defensive: stdlib cpy behavior
        if port < 0:
            raise socket.gaierror(
                "Servname not supported for ai_socktype"
            )  # Defensive: stdlib cpy behavior

        if isinstance(host, bytes):
            host = host.decode("ascii")  # Defensive: stdlib cpy behavior

        if is_ipv4(host):
            if family == socket.AF_INET6:
                raise socket.gaierror(
                    "Address family for hostname not supported"
                )  # Defensive: stdlib cpy behavior
            # NOTE(review): proto hard-coded to 6 (TCP) regardless of socket
            # kind — mirrors the sibling InMemoryResolver; confirm intent.
            return [
                (
                    socket.AF_INET,
                    type,
                    6,
                    "",
                    (
                        host,
                        port,
                    ),
                )
            ]
        elif is_ipv6(host):
            if family == socket.AF_INET:
                raise socket.gaierror(
                    "Address family for hostname not supported"
                )  # Defensive: stdlib cpy behavior
            return [
                (
                    socket.AF_INET6,
                    type,
                    17,
                    "",
                    (
                        host,
                        port,
                        0,
                        0,
                    ),
                )
            ]

        raise socket.gaierror(f"Tried to resolve '{host}' using the NullResolver")
|
||||
|
||||
|
||||
__all__ = ("NullResolver",)
|
||||
@@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from random import randint
|
||||
|
||||
from ..._constant import UDP_LINUX_GRO
|
||||
from ..._typing import _TYPE_SOCKET_OPTIONS, _TYPE_TIMEOUT_INTERNAL
|
||||
from ...exceptions import LocationParseError
|
||||
from ...util.connection import _set_socket_options, allowed_gai_family
|
||||
from ...util.ssl_match_hostname import CertificateError, match_hostname
|
||||
from ...util.timeout import _DEFAULT_TIMEOUT
|
||||
from .utils import inet4_ntoa, inet6_ntoa, parse_https_rdata
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .utils import HttpsRecord
|
||||
|
||||
|
||||
class ProtocolResolver(str, Enum):
    """
    At urllib3.future we aim to propose a wide range of DNS-protocols.
    The most used techniques are available.

    Values double as the scheme token accepted in DNS urls
    (e.g. ``doh+google://``), see ``ResolverDescription.from_url``.
    """

    #: Ask the OS native DNS layer
    SYSTEM = "system"
    #: DNS over HTTPS
    DOH = "doh"
    #: DNS over QUIC
    DOQ = "doq"
    #: DNS over TLS
    DOT = "dot"
    #: DNS over UDP (insecure)
    DOU = "dou"
    #: Manual (e.g. hosts)
    MANUAL = "in-memory"
    #: Void (e.g. purposely disable resolution)
    NULL = "null"
    #: Custom (e.g. your own implementation, use this when it does not suit any of the protocols specified)
    CUSTOM = "custom"
|
||||
|
||||
|
||||
class BaseResolver(metaclass=ABCMeta):
    """Abstract contract every DNS resolver implementation must honor.

    Subclasses declare their protocol/specifier/implementation class vars,
    then provide ``close``, ``is_available`` and ``getaddrinfo``. The base
    class supplies host-pattern constraint matching, recycling, and a
    ``create_connection`` helper built atop the subclass's ``getaddrinfo``.
    """

    # Transport used by the implementation (doh, dot, dou, system, ...).
    protocol: typing.ClassVar[ProtocolResolver]
    # Optional vendor shortcut name (e.g. "cloudflare"); None means generic.
    specifier: typing.ClassVar[str | None] = None

    # Backend name distinguishing concrete implementations.
    implementation: typing.ClassVar[str]

    def __init__(
        self,
        server: str | None,
        port: int | None = None,
        *patterns: str,
        **kwargs: typing.Any,
    ) -> None:
        self._server = server
        self._port = port
        self._host_patterns: tuple[str, ...] = patterns
        self._lock = threading.Lock()
        # Kept verbatim so recycle() can rebuild an equivalent instance.
        self._kwargs = kwargs

        if not self._host_patterns and "patterns" in kwargs:
            self._host_patterns = kwargs["patterns"]

        # allow to temporarily expose a sock that is "being" created
        # this helps with our Happy Eyeballs implementation in sync.
        self._unsafe_expose: bool = False
        self._sock_cursor: socket.socket | None = None

    def recycle(self) -> BaseResolver:
        """Build a fresh instance from the constructor arguments of a closed one.

        :raises RuntimeError: if the resolver is still available (not closed).
        """
        if self.is_available():
            raise RuntimeError("Attempting to recycle a Resolver that was not closed")

        # Inspect the subclass constructor to know which call shape it takes.
        args = list(self.__class__.__init__.__code__.co_varnames)
        args.remove("self")

        kwargs_cpy = deepcopy(self._kwargs)

        if self._server:
            kwargs_cpy["server"] = self._server
        if self._port:
            kwargs_cpy["port"] = self._port

        if "patterns" in args and "kwargs" in args:
            return self.__class__(*self._host_patterns, **kwargs_cpy)  # type: ignore[arg-type]
        elif "kwargs" in args:
            return self.__class__(**kwargs_cpy)

        return self.__class__()  # type: ignore[call-arg]

    @property
    def server(self) -> str | None:
        # Upstream DNS server hostname/address, if any.
        return self._server

    @property
    def port(self) -> int | None:
        # Upstream DNS server port, if any.
        return self._port

    def have_constraints(self) -> bool:
        """True when this resolver is restricted to a set of host patterns."""
        return bool(self._host_patterns)

    def support(self, hostname: str | bytes | None) -> bool | None:
        """
        Determine if given hostname is especially resolvable by given resolver.
        If this resolver does not have any constrained list of host, it returns None. Meaning
        it support any hostname for resolution.
        """
        if not self._host_patterns:
            return None
        if hostname is None:
            hostname = "localhost"
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        # Reuse TLS hostname matching so patterns support wildcards.
        try:
            match_hostname(
                {"subjectAltName": (tuple(("DNS", e) for e in self._host_patterns))},
                hostname,
            )
        except CertificateError:
            return False
        return True

    @abstractmethod
    def close(self) -> None:
        """Terminate the given resolver instance. This should render it unusable. Further inquiries should raise an exception."""
        raise NotImplementedError

    @abstractmethod
    def is_available(self) -> bool:
        """Determine if Resolver can receive inquiries."""
        raise NotImplementedError

    @abstractmethod
    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """This method align itself on the standard library socket.getaddrinfo(). It must be implemented as-is on your Resolver."""
        raise NotImplementedError

    # This function is copied from socket.py in the Python 2.7 standard
    # library test suite. Added to its signature is only `socket_options`.
    # One additional modification is that we avoid binding to IPv6 servers
    # discovered in DNS if the system doesn't have IPv6 functionality.
    def create_connection(
        self,
        address: tuple[str, int],
        timeout: _TYPE_TIMEOUT_INTERNAL = _DEFAULT_TIMEOUT,
        source_address: tuple[str, int] | None = None,
        socket_options: _TYPE_SOCKET_OPTIONS | None = None,
        socket_kind: socket.SocketKind = socket.SOCK_STREAM,
        *,
        quic_upgrade_via_dns_rr: bool = False,
        timing_hook: typing.Callable[[tuple[timedelta, timedelta, datetime]], None]
        | None = None,
        default_socket_family: socket.AddressFamily = socket.AF_UNSPEC,
    ) -> socket.socket:
        """Connect to *address* and return the socket object.

        Convenience function. Connect to *address* (a 2-tuple ``(host,
        port)``) and return the socket object. Passing the optional
        *timeout* parameter will set the timeout on the socket instance
        before attempting to connect. If no *timeout* is supplied, the
        global default timeout setting returned by :func:`socket.getdefaulttimeout`
        is used. If *source_address* is set it must be a tuple of (host, port)
        for the socket to bind as a source address before making the connection.
        An host of '' or port 0 tells the OS to use the default.
        """

        host, port = address
        if host.startswith("["):
            host = host.strip("[]")
        err = None

        # Using the value from allowed_gai_family() in the context of getaddrinfo lets
        # us select whether to work with IPv4 DNS records, IPv6 records, or both.
        # The original create_connection function always returns all records.
        family = allowed_gai_family()

        if family != socket.AF_UNSPEC:
            default_socket_family = family

        # A bound source address forces the matching address family.
        if source_address is not None:
            if isinstance(
                ipaddress.ip_address(source_address[0]), ipaddress.IPv4Address
            ):
                default_socket_family = socket.AF_INET
            else:
                default_socket_family = socket.AF_INET6

        try:
            host.encode("idna")
        except UnicodeError:
            raise LocationParseError(f"'{host}', label empty or too long") from None

        # Time the DNS resolution and the connection separately for the hook.
        dt_pre_resolve = datetime.now(tz=timezone.utc)
        records = self.getaddrinfo(
            host,
            port,
            default_socket_family,
            socket_kind,
            quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
        )
        delta_post_resolve = datetime.now(tz=timezone.utc) - dt_pre_resolve

        dt_pre_established = datetime.now(tz=timezone.utc)
        # Try each resolved record in order; first successful connect wins.
        for res in records:
            af, socktype, proto, canonname, sa = res
            sock = None
            try:
                sock = socket.socket(af, socktype, proto)

                # we need to add this or reusing the same origin port will likely fail within
                # short period of time. kernel put port on wait shut.
                if source_address is not None:
                    try:
                        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
                    except (
                        OSError,
                        AttributeError,
                    ):  # Defensive: Windows or very old OS?
                        try:
                            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                        except (
                            OSError,
                            AttributeError,
                        ):  # Defensive: we can't do anything better than this.
                            pass

                    try:
                        sock.setsockopt(
                            socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
                        )
                    except (OSError, AttributeError):
                        pass

                    sock.bind(source_address)

                # attempt to leverage GRO when under Linux
                if socktype == socket.SOCK_DGRAM and sys.platform == "linux":
                    try:
                        sock.setsockopt(socket.SOL_UDP, UDP_LINUX_GRO, 1)
                    except OSError:  # Defensive: oh, well(...) anyway!
                        pass

                # If provided, set socket level options before connecting.
                _set_socket_options(sock, socket_options)

                if timeout is not _DEFAULT_TIMEOUT:
                    sock.settimeout(timeout)

                if self._unsafe_expose:
                    self._sock_cursor = sock

                sock.connect(sa)

                if self._unsafe_expose:
                    self._sock_cursor = None
                # Break explicitly a reference cycle
                err = None

                delta_post_established = (
                    datetime.now(tz=timezone.utc) - dt_pre_established
                )

                if timing_hook is not None:
                    timing_hook(
                        (
                            delta_post_resolve,
                            delta_post_established,
                            datetime.now(tz=timezone.utc),
                        )
                    )

                return sock
            except (OSError, OverflowError) as _:
                err = _
                if sock is not None:
                    sock.close()
                if isinstance(_, OverflowError):
                    break

        if err is not None:
            try:
                raise err
            finally:
                # Break explicitly a reference cycle
                err = None
        else:
            raise OSError("getaddrinfo returns an empty list")
|
||||
|
||||
|
||||
class ManyResolver(BaseResolver):
    """
    Special resolver that use many child resolver. Priorities
    are based on given order (list of BaseResolver).
    """

    def __init__(self, *resolvers: BaseResolver) -> None:
        super().__init__(None, None)

        self._size = len(resolvers)

        # Children are split by whether they are limited to host patterns;
        # constrained ones are consulted first in getaddrinfo().
        self._unconstrained: list[BaseResolver] = [
            _ for _ in resolvers if not _.have_constraints()
        ]
        self._constrained: list[BaseResolver] = [
            _ for _ in resolvers if _.have_constraints()
        ]

        # Number of in-flight __resolvers() walks; used to rotate start index.
        self._concurrent: int = 0
        self._terminated: bool = False

    def recycle(self) -> BaseResolver:
        """Return a new ManyResolver wrapping recycled child resolvers."""
        resolvers = []

        for resolver in self._unconstrained + self._constrained:
            resolvers.append(resolver.recycle())

        return ManyResolver(*resolvers)

    def close(self) -> None:
        """Close every child resolver and mark this one terminated."""
        for resolver in self._unconstrained + self._constrained:
            resolver.close()

        self._terminated = True

    def is_available(self) -> bool:
        return not self._terminated

    def __resolvers(
        self, constrained: bool = False
    ) -> typing.Generator[BaseResolver, None, None]:
        """Yield child resolvers, recycling closed ones on the fly.

        The starting index rotates with the number of concurrent walks so
        simultaneous lookups spread across the children.
        """
        resolvers = self._unconstrained if not constrained else self._constrained

        if not resolvers:
            return

        with self._lock:
            self._concurrent += 1

        try:
            resolver_count = len(resolvers)
            start_idx = (self._concurrent - 1) % resolver_count

            for idx in range(start_idx, resolver_count):
                if not resolvers[idx].is_available():
                    with self._lock:
                        resolvers[idx] = resolvers[idx].recycle()
                yield resolvers[idx]

            # Wrap around to cover the children skipped by the rotated start.
            if start_idx > 0:
                for idx in range(0, start_idx):
                    if not resolvers[idx].is_available():
                        with self._lock:
                            resolvers[idx] = resolvers[idx].recycle()
                    yield resolvers[idx]
        finally:
            with self._lock:
                self._concurrent -= 1

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve via constrained children first, then unconstrained ones.

        DNSSEC/DNSKEY failures are fatal and re-raised immediately; other
        gaierrors make the lookup fall through to the next child.

        :raises socket.gaierror: when no child produces a result.
        """
        if isinstance(host, bytes):
            host = host.decode("ascii")
        if host is None:
            host = "localhost"

        # NOTE(review): collected but never read afterwards — dead bookkeeping?
        tested_resolvers = []

        any_constrained_tried: bool = False

        for resolver in self.__resolvers(True):
            can_resolve = resolver.support(host)

            if can_resolve is True:
                any_constrained_tried = True

                try:
                    results = resolver.getaddrinfo(
                        host,
                        port,
                        family,
                        type,
                        proto,
                        flags,
                        quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                    )

                    if results:
                        return results
                except socket.gaierror as exc:
                    # Security-relevant failures must not be silently skipped.
                    if isinstance(exc.args[0], str) and (
                        "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                    ):
                        raise
                    continue
            elif can_resolve is False:
                tested_resolvers.append(resolver)

        # A constrained resolver claimed this host but failed: do not leak the
        # query to unconstrained resolvers.
        if any_constrained_tried:
            raise socket.gaierror(
                f"Name or service not known: {host} using {self._size - len(self._unconstrained)} resolver(s)"
            )

        for resolver in self.__resolvers():
            try:
                results = resolver.getaddrinfo(
                    host,
                    port,
                    family,
                    type,
                    proto,
                    flags,
                    quic_upgrade_via_dns_rr=quic_upgrade_via_dns_rr,
                )

                if results:
                    return results
            except socket.gaierror as exc:
                if isinstance(exc.args[0], str) and (
                    "DNSSEC" in exc.args[0] or "DNSKEY" in exc.args[0]
                ):
                    raise
                continue

        raise socket.gaierror(
            f"Name or service not known: {host} using {self._size - len(self._constrained)} resolver(s)"
        )
|
||||
|
||||
|
||||
class SupportedQueryType(int, Enum):
    """
    urllib3.future does not need anything else so far. let's be pragmatic.
    Each type is associated with its hex value as per the RFC.
    """

    #: IPv4 host address record.
    A = 0x0001
    #: IPv6 host address record.
    AAAA = 0x001C
    #: HTTPS service binding record (carries e.g. the "alpn" hints read upstream).
    HTTPS = 0x0041
|
||||
|
||||
|
||||
class DomainNameServerQuery:
    """
    Minimalist DNS query/message to ask for A, AAAA and HTTPS records.
    Only meant for urllib3.future use. Does not cover all of possible extent of use.
    """

    def __init__(
        self, host: str, query_type: SupportedQueryType, override_id: int | None = None
    ) -> None:
        # 16-bit transaction id, random unless the caller pins one.
        chosen_id = randint(0x0000, 0xFFFF) if override_id is None else override_id
        self._id = struct.pack("!H", chosen_id)
        self._host = host
        self._query = query_type
        # Standard query flags with the recursion-desired bit set.
        self._flags = struct.pack("!H", 0x0100)
        # Exactly one question per message.
        self._qd_count = struct.pack("!H", 1)

        # Wire representation, built lazily (once) by __bytes__().
        self._cached: bytes | None = None

    @property
    def id(self) -> int:
        """Transaction id as an integer."""
        return struct.unpack("!H", self._id)[0]  # type: ignore[no-any-return]

    @property
    def raw_id(self) -> bytes:
        """Transaction id in its two-byte wire form."""
        return self._id

    def __repr__(self) -> str:
        return f"<Query '{self._host}' IN {self._query.name}>"

    def __bytes__(self) -> bytes:
        """Serialize the query into DNS wire format (header + one question)."""
        if self._cached:
            return self._cached

        # Header: id, flags, qdcount, then an/ns/ar counts all zero.
        parts = [
            self._id,
            self._flags,
            self._qd_count,
            b"\x00\x00",
            b"\x00\x00",
            b"\x00\x00",
        ]

        # QNAME: length-prefixed labels, terminated by a zero byte.
        for label in self._host.split("."):
            parts.append(struct.pack("!B", len(label)))
            parts.append(label.encode("ascii"))
        parts.append(b"\x00")

        # QTYPE followed by QCLASS (always IN = 0x0001).
        parts.append(struct.pack("!H", self._query.value))
        parts.append(struct.pack("!H", 0x0001))

        self._cached = b"".join(parts)

        return self._cached

    @staticmethod
    def bulk(host: str, *types: SupportedQueryType) -> list[DomainNameServerQuery]:
        """Build one query per requested record type, each with its own id."""
        return [
            DomainNameServerQuery(host, query_type=query_type) for query_type in types
        ]
|
||||
|
||||
|
||||
#: Most common status code, not exhaustive at all.
#: Maps a DNS RCODE value to a human-readable label for error messages.
COMMON_RCODE_LABEL: dict[int, str] = {
    0: "No Error",
    1: "Format Error",
    2: "Server Failure",
    3: "Non-Existent Domain",
    5: "Query Refused",
    9: "Not Authorized",
}
|
||||
|
||||
|
||||
# Raised when a DNS response payload is truncated or otherwise malformed.
class DomainNameServerParseException(Exception): ...
|
||||
|
||||
|
||||
class DomainNameServerReturn:
|
||||
"""
|
||||
Minimalist DNS response parser. Allow to quickly extract key-data out of it.
|
||||
Meant for A, AAAA and HTTPS records. Basically only what we need.
|
||||
"""
|
||||
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
try:
|
||||
up = struct.unpack("!HHHHHH", payload[:12])
|
||||
|
||||
self._id = up[0]
|
||||
self._flags = up[1]
|
||||
self._qd_count = up[2]
|
||||
self._an_count = up[3]
|
||||
|
||||
self._rcode = int(f"0x{hex(payload[3])[-1]}", 16)
|
||||
|
||||
self._hostname: str = ""
|
||||
|
||||
idx = 12
|
||||
|
||||
while True:
|
||||
c = payload[idx]
|
||||
|
||||
if c == 0:
|
||||
idx += 1
|
||||
break
|
||||
|
||||
self._hostname += payload[idx + 1 : idx + 1 + c].decode("ascii") + "."
|
||||
|
||||
idx += c + 1
|
||||
|
||||
self._records: list[tuple[SupportedQueryType, int, str | HttpsRecord]] = []
|
||||
|
||||
if self._an_count:
|
||||
idx += 4
|
||||
|
||||
while idx < len(payload):
|
||||
up = struct.unpack("!HHHI", payload[idx : idx + 10])
|
||||
entry_size = struct.unpack("!H", payload[idx + 10 : idx + 12])[0]
|
||||
|
||||
data = payload[idx + 12 : idx + 12 + entry_size]
|
||||
|
||||
if len(data) == 4:
|
||||
decoded_data: str | HttpsRecord = inet4_ntoa(data)
|
||||
elif len(data) == 16:
|
||||
decoded_data = inet6_ntoa(data)
|
||||
elif data:
|
||||
decoded_data = parse_https_rdata(data)
|
||||
else:
|
||||
continue
|
||||
|
||||
try:
|
||||
self._records.append(
|
||||
(SupportedQueryType(up[1]), up[-1], decoded_data)
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
idx += 12 + entry_size
|
||||
except (struct.error, IndexError, ValueError, UnicodeDecodeError) as e:
|
||||
raise DomainNameServerParseException(
|
||||
"A protocol error occurred while parsing the DNS response payload: "
|
||||
f"{str(e)}"
|
||||
) from e
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self._hostname
|
||||
|
||||
@property
|
||||
def records(self) -> list[tuple[SupportedQueryType, int, str | HttpsRecord]]:
|
||||
return self._records
|
||||
|
||||
@property
|
||||
def is_found(self) -> bool:
|
||||
return bool(self._records)
|
||||
|
||||
@property
|
||||
def rcode(self) -> int:
|
||||
return self._rcode
|
||||
|
||||
@property
|
||||
def is_ok(self) -> bool:
|
||||
return self._rcode == 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.is_ok:
|
||||
return f"<Records '{self.hostname}' {self._records}>"
|
||||
return f"<DNS Error '{self.hostname}' with Status {self.rcode} ({COMMON_RCODE_LABEL[self.rcode] if self.rcode in COMMON_RCODE_LABEL else 'Unknown'})>"
|
||||
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ._socket import SystemResolver
|
||||
|
||||
__all__ = ("SystemResolver",)
|
||||
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import typing
|
||||
|
||||
from ..protocols import BaseResolver, ProtocolResolver
|
||||
|
||||
|
||||
class SystemResolver(BaseResolver):
    """Resolver backed by the operating system, i.e. ``socket.getaddrinfo``."""

    implementation = "socket"
    protocol = ProtocolResolver.SYSTEM

    def __init__(self, *patterns: str, **kwargs: typing.Any):
        # The OS resolver has no remote endpoint: silently discard any
        # server/port hints the caller may have passed along.
        kwargs.pop("server", None)
        kwargs.pop("port", None)
        super().__init__(None, None, *patterns, **kwargs)

    def support(self, hostname: str | bytes | None) -> bool | None:
        # Always handle the null host and localhost ourselves.
        if hostname is None:
            return True
        if isinstance(hostname, bytes):
            hostname = hostname.decode("ascii")
        return True if hostname == "localhost" else super().support(hostname)

    def recycle(self) -> BaseResolver:
        # Nothing to rebuild: this resolver holds no connection state.
        return self

    def close(self) -> None:
        pass  # no-op!

    def is_available(self) -> bool:
        # The system resolver cannot be exhausted nor closed.
        return True

    def getaddrinfo(
        self,
        host: bytes | str | None,
        port: str | int | None,
        family: socket.AddressFamily,
        type: socket.SocketKind,
        proto: int = 0,
        flags: int = 0,
        *,
        quic_upgrade_via_dns_rr: bool = False,
    ) -> list[
        tuple[
            socket.AddressFamily,
            socket.SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        # quic_upgrade_via_dns_rr is ignored here: the system resolver
        # cannot surface HTTPS/SVCB records.
        # the | tuple[int, bytes] is silently ignored, can't happen with our cases.
        return socket.getaddrinfo(host, port, family, type, proto, flags)  # type: ignore[return-value]
|
||||
@@ -0,0 +1,322 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import socket
|
||||
import struct
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:

    class HttpsRecord(typing.TypedDict):
        """Decoded SVCB/HTTPS resource record (exists for type checking only)."""

        # SvcPriority field of the record.
        priority: int
        # TargetName; "." when the record carried an empty name.
        target: str
        # ALPN protocol identifiers advertised by the record.
        alpn: list[str]
        # IPv4 address hints, dotted-quad text.
        ipv4hint: list[str]
        # IPv6 address hints, compressed text form.
        ipv6hint: list[str]
        # ECH configurations, base64-encoded.
        echconfig: list[str]
|
||||
|
||||
|
||||
def inet4_ntoa(address: bytes) -> str:
    """
    Convert an IPv4 address from its 4-byte packed form to dotted-quad text.
    """
    if len(address) != 4:
        raise ValueError(
            f"IPv4 addresses are 4 bytes long, got {len(address)} byte(s) instead"
        )

    # Iterating a bytes object yields ints 0..255; join them with dots.
    return ".".join(str(octet) for octet in address)
|
||||
|
||||
|
||||
def inet6_ntoa(address: bytes) -> str:
    """
    Convert an IPv6 address from bytes to str.

    Produces the compressed text form: the longest run of zero groups is
    collapsed to ``::``, and v4-embedded addresses are rendered with a
    dotted-quad tail (e.g. ``::ffff:1.2.3.4``).
    """
    if len(address) != 16:
        raise ValueError(
            f"IPv6 addresses are 16 bytes long, got {len(address)} byte(s) instead"
        )

    # NOTE(review): `hex` shadows the builtin within this function.
    hex = binascii.hexlify(address)
    chunks = []

    i = 0
    length = len(hex)

    # Split the 32 hex digits into 8 groups of 4, stripping leading zeros
    # (an all-zero group becomes "0").
    while i < length:
        chunk = hex[i : i + 4].decode().lstrip("0") or "0"
        chunks.append(chunk)
        i += 4

    # Compress the longest subsequence of 0-value chunks to ::
    # Track the first longest run of "0" groups (ties keep the earliest run).
    best_start = 0
    best_len = 0
    start = -1
    last_was_zero = False

    for i in range(8):
        if chunks[i] != "0":
            if last_was_zero:
                end = i
                current_len = end - start
                if current_len > best_len:
                    best_start = start
                    best_len = current_len
                last_was_zero = False
        elif not last_was_zero:
            start = i
            last_was_zero = True
    # A run of zeros reaching the end of the address is closed here.
    if last_was_zero:
        end = 8
        current_len = end - start
        if current_len > best_len:
            best_start = start
            best_len = current_len
    # Only runs of two or more zero groups are worth compressing.
    if best_len > 1:
        if best_start == 0 and (best_len == 6 or best_len == 5 and chunks[5] == "ffff"):
            # We have an embedded IPv4 address
            if best_len == 6:
                prefix = "::"
            else:
                prefix = "::ffff:"
            # The last 4 bytes are rendered as a dotted quad.
            thex = prefix + inet4_ntoa(address[12:])
        else:
            thex = (
                ":".join(chunks[:best_start])
                + "::"
                + ":".join(chunks[best_start + best_len :])
            )
    else:
        thex = ":".join(chunks)

    return thex
|
||||
|
||||
|
||||
def packet_fragment(payload: bytes, *identifiers: bytes) -> tuple[bytes, ...]:
    """Split a buffer holding several concatenated DNS messages into one
    fragment per message, locating each message start through its 2-byte
    transaction *identifiers*.

    :raises ValueError: when none of the given identifiers can be located.
    """
    results = []

    # Byte offset of the lead message inside the buffer, in case the payload
    # does not start exactly on a message boundary.
    offset = 0

    start_packet_idx = []
    lead_identifier = None

    # Anchor on the first identifier found within the leading 12-byte
    # header window; all other positions are expressed relative to it.
    for identifier in identifiers:
        idx = payload[:12].find(identifier)

        if idx == -1:
            continue

        if idx != 0:
            offset = idx

        start_packet_idx.append(idx - offset)

        lead_identifier = identifier
        break

    # Locate the remaining messages anywhere in the buffer.
    for identifier in identifiers:
        if identifier == lead_identifier:
            continue

        if offset == 0:
            # NOTE(review): the b"\x02" prefix narrows the search to reduce
            # false positives — presumably matching a byte that precedes the
            # id on the wire; confirm against the callers' framing.
            idx = payload.find(b"\x02" + identifier)
        else:
            idx = payload.find(identifier)

        if idx == -1:
            continue

        start_packet_idx.append(idx - offset)

    if not start_packet_idx:
        raise ValueError(
            "no identifiable dns message emerged from given payload. "
            "this should not happen at all. networking issue?"
        )

    # Single message: hand the whole buffer back untouched.
    if len(start_packet_idx) == 1:
        return (payload,)

    start_packet_idx = sorted(start_packet_idx)

    previous_idx = None

    # Slice the buffer between consecutive message start positions.
    for idx in start_packet_idx:
        if previous_idx is None:
            previous_idx = idx
            continue
        results.append(payload[previous_idx:idx])
        previous_idx = idx

    # The last message runs to the end of the buffer.
    results.append(payload[previous_idx:])

    return tuple(results)
|
||||
|
||||
|
||||
def is_ipv4(addr: str) -> bool:
    """Return True when *addr* parses as an IPv4 address (per ``inet_aton``)."""
    try:
        socket.inet_aton(addr)
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def is_ipv6(addr: str) -> bool:
    """Return True when *addr* parses as an IPv6 address (per ``inet_pton``)."""
    try:
        socket.inet_pton(socket.AF_INET6, addr)
    except OSError:
        return False
    return True
|
||||
|
||||
|
||||
def validate_length_of(hostname: str) -> None:
    """RFC 1035 impose a limit on a domain name length. We verify it there.

    :raises UnicodeError: when the full name exceeds 253 characters or any
        single label exceeds 63 characters.
    """
    # Surrounding dots do not count toward the name length.
    if len(hostname.strip(".")) > 253:
        raise UnicodeError("hostname to resolve exceed 253 characters")
    # Generator expression instead of a throwaway list (ruff C419).
    if any(len(label) > 63 for label in hostname.split(".")):
        raise UnicodeError("at least one label to resolve exceed 63 characters")
|
||||
|
||||
|
||||
def rfc1035_should_read(payload: bytes) -> bool:
    """Tell whether *payload* ends mid-message and more bytes must be read.

    DNS over TCP (RFC 1035 section 4.2.2) prefixes each message with a 2-byte
    big-endian length. Walk the concatenated messages: return False when the
    buffer ends exactly on a message boundary, True when the last message
    (or its length prefix) is still incomplete.
    """
    if not payload:
        return False
    if len(payload) <= 2:
        return True

    cursor = payload

    while True:
        # A dangling 0/1-byte length prefix means the next message is still
        # in flight. Previously this case crashed struct.unpack mid-stream.
        if len(cursor) < 2:
            return True

        expected_size: int = struct.unpack("!H", cursor[:2])[0]

        if len(cursor[2:]) == expected_size:
            return False
        elif len(cursor[2:]) < expected_size:
            return True

        cursor = cursor[2 + expected_size :]
|
||||
|
||||
|
||||
def rfc1035_unpack(payload: bytes) -> tuple[bytes, ...]:
    """Split a DNS-over-TCP style stream into its length-prefixed messages."""
    messages = []
    remaining = payload

    while remaining:
        # Each message is preceded by a 2-byte big-endian length.
        (size,) = struct.unpack("!H", remaining[:2])
        messages.append(remaining[2 : 2 + size])
        remaining = remaining[2 + size :]

    return tuple(messages)
|
||||
|
||||
|
||||
def rfc1035_pack(message: bytes) -> bytes:
    """Prefix *message* with its 2-byte big-endian length (RFC 1035 TCP framing)."""
    header = struct.pack("!H", len(message))
    return header + message
|
||||
|
||||
|
||||
def read_name(data: bytes, offset: int) -> tuple[str, int]:
    """
    Read a DNS-encoded name (with compression pointers) from data[offset:].
    Returns (name, new_offset).
    """
    labels: list[str] = []

    while True:
        marker = data[offset]

        # A zero-length label terminates the name.
        if marker == 0:
            offset += 1
            break

        # Compression pointer (two high bits set): the rest of the name lives
        # at the 14-bit offset; a pointer always ends the sequence.
        if marker & 0xC0 == 0xC0:
            target = struct.unpack_from("!H", data, offset)[0] & 0x3FFF
            tail, _ = read_name(data, target)
            labels.append(tail)
            offset += 2
            break

        offset += 1
        labels.append(data[offset : offset + marker].decode())
        offset += marker

    return ".".join(labels), offset
|
||||
|
||||
|
||||
def parse_echconfigs(buf: bytes) -> list[str]:
    """
    buf is the raw bytes of the ECHConfig vector:
      - 2-byte total length, then for each:
      - 2-byte cfg length + that many bytes of cfg
    We return a list of Base64 strings (one per config).
    """
    if len(buf) < 2:
        return []

    (vector_len,) = struct.unpack_from("!H", buf, 0)
    limit = 2 + vector_len

    cursor = 2
    configs: list[str] = []

    # Walk each length-prefixed config until the declared vector end.
    while cursor + 2 <= limit:
        (size,) = struct.unpack_from("!H", buf, cursor)
        cursor += 2
        configs.append(base64.b64encode(buf[cursor : cursor + size]).decode())
        cursor += size

    return configs
|
||||
|
||||
|
||||
def parse_https_rdata(rdata: bytes) -> HttpsRecord:
    """
    Parse the RDATA of an SVCB/HTTPS record.
    Returns a dict with keys: priority, target, alpn, ipv4hint, ipv6hint, echconfig.
    """
    offset = 0
    priority = struct.unpack_from("!H", rdata, offset)[0]
    offset += 2

    target, offset = read_name(rdata, offset)

    # Collect every SvcParam as raw bytes, keyed by its numeric SvcParamKey.
    params: dict[int, bytes] = {}
    while offset + 4 <= len(rdata):
        key, length = struct.unpack_from("!HH", rdata, offset)
        offset += 4
        params[key] = rdata[offset : offset + length]
        offset += length

    # key=1: ALPN — a sequence of length-prefixed protocol identifiers.
    def _split_alpn(buf: bytes) -> list[str]:
        protocols: list[str] = []
        pos = 0
        while pos < len(buf):
            size = buf[pos]
            protocols.append(buf[pos + 1 : pos + 1 + size].decode())
            pos += 1 + size
        return protocols

    # key=4 / key=6: fixed-width address hints (4 bytes IPv4, 16 bytes IPv6).
    raw_v4 = params.get(4, b"")
    raw_v6 = params.get(6, b"")

    return {
        "priority": priority,
        "target": target or ".",  # empty name → root
        "alpn": _split_alpn(params.get(1, b"")),
        "ipv4hint": [inet4_ntoa(raw_v4[i : i + 4]) for i in range(0, len(raw_v4), 4)],
        "ipv6hint": [inet6_ntoa(raw_v6[i : i + 16]) for i in range(0, len(raw_v6), 16)],
        "echconfig": parse_echconfigs(params.get(5, b"")),  # key=5: ECHConfig
    }
|
||||
|
||||
|
||||
__all__ = (
|
||||
"inet4_ntoa",
|
||||
"inet6_ntoa",
|
||||
"packet_fragment",
|
||||
"is_ipv4",
|
||||
"is_ipv6",
|
||||
"validate_length_of",
|
||||
"rfc1035_pack",
|
||||
"rfc1035_unpack",
|
||||
"rfc1035_should_read",
|
||||
"parse_https_rdata",
|
||||
)
|
||||
Reference in New Issue
Block a user