Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/prefect/blocks/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from prefect.logging import LogEavesdropper
from prefect.types import SecretDict
from prefect.utilities.templating import apply_values, find_placeholders
from prefect.utilities.urls import validate_restricted_url
from prefect.utilities.urls import (
SSRFProtectedAsyncHTTPTransport,
SSRFProtectedHTTPTransport,
validate_restricted_url,
)

PREFECT_NOTIFY_TYPE_DEFAULT = "info" # Use a valid apprise type as default

Expand Down Expand Up @@ -987,13 +991,18 @@ async def anotify(self, body: str, subject: str | None = None) -> None:
import httpx

request_args = self._build_request_args(body, subject)
client_kwargs: dict[str, Any] = {
"headers": {"user-agent": "Prefect Notifications"},
}
if not self.allow_private_urls:
validate_restricted_url(request_args["url"])
# Re-validate at connection time and pin the resolved IP to close
# the DNS-rebinding TOCTOU window.
client_kwargs["transport"] = SSRFProtectedAsyncHTTPTransport()
cookies = request_args.pop("cookies", dict())
client_kwargs["cookies"] = cookies
# make request with httpx
async with httpx.AsyncClient(
headers={"user-agent": "Prefect Notifications"}, cookies=cookies
) as client:
async with httpx.AsyncClient(**client_kwargs) as client:
resp = await client.request(**request_args)
resp.raise_for_status()

Expand All @@ -1002,13 +1011,16 @@ def notify(self, body: str, subject: str | None = None) -> None:
import httpx

request_args = self._build_request_args(body, subject)
client_kwargs: dict[str, Any] = {
"headers": {"user-agent": "Prefect Notifications"},
}
if not self.allow_private_urls:
validate_restricted_url(request_args["url"])
client_kwargs["transport"] = SSRFProtectedHTTPTransport()
cookies = request_args.pop("cookies", dict())
client_kwargs["cookies"] = cookies
# make request with httpx
with httpx.Client(
headers={"user-agent": "Prefect Notifications"}, cookies=cookies
) as client:
with httpx.Client(**client_kwargs) as client:
resp = client.request(**request_args)
resp.raise_for_status()

Expand Down
20 changes: 16 additions & 4 deletions src/prefect/blocks/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,21 @@

from prefect.blocks.core import Block
from prefect.types import SecretDict
from prefect.utilities.urls import validate_restricted_url
from prefect.utilities.urls import (
SSRFProtectedAsyncHTTPTransport,
validate_restricted_url,
)

# Use a global HTTP transport to maintain a process-wide connection pool for
# interservice requests
_http_transport = AsyncHTTPTransport()
_insecure_http_transport = AsyncHTTPTransport(verify=False)
# Separate pools for calls that must be protected from DNS-rebinding SSRF. The
# protected transport validates the resolved IP at connection time and connects
# to the pre-resolved address, closing the TOCTOU window exploited by DNS
# rebinding attacks.
_safe_http_transport = SSRFProtectedAsyncHTTPTransport()
_safe_insecure_http_transport = SSRFProtectedAsyncHTTPTransport(verify=False)


class Webhook(Block):
Expand Down Expand Up @@ -53,10 +62,13 @@ class Webhook(Block):
)

def block_initialization(self) -> None:
if self.verify:
self._client = AsyncClient(transport=_http_transport)
if self.allow_private_urls:
transport = _http_transport if self.verify else _insecure_http_transport
else:
self._client = AsyncClient(transport=_insecure_http_transport)
transport = (
_safe_http_transport if self.verify else _safe_insecure_http_transport
)
self._client = AsyncClient(transport=transport)

async def call(self, payload: dict[str, Any] | str | None = None) -> Response:
"""
Expand Down
238 changes: 230 additions & 8 deletions src/prefect/utilities/urls.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import inspect
import ipaddress
import socket
import time
import urllib.parse
from collections.abc import Iterable
from logging import Logger
from string import Formatter
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from urllib.parse import urlparse
from uuid import UUID

import anyio.to_thread
import httpcore
import httpx
from pydantic import BaseModel

from prefect import settings
Expand Down Expand Up @@ -71,7 +76,14 @@ def validate_restricted_url(url: str) -> None:
Validate that the provided URL is safe for outbound requests. This prevents
attacks like SSRF (Server Side Request Forgery), where an attacker can make
requests to internal services (like the GCP metadata service, localhost addresses,
or in-cluster Kubernetes services)
or in-cluster Kubernetes services).

This is a pre-flight check that validates every address the hostname resolves
to via `getaddrinfo`. Because DNS can change between this check and the
actual HTTP connection, callers that need hardened SSRF protection should
also use `SSRFProtectedAsyncHTTPTransport` / `SSRFProtectedHTTPTransport`,
which re-validate at connection time and connect to the pre-resolved IP to
close the TOCTOU window exploited by DNS rebinding attacks.

Args:
url: The URL to validate.
Expand Down Expand Up @@ -100,17 +112,227 @@ def validate_restricted_url(url: str) -> None:
raise ValueError(f"{url!r} is not a valid URL.")

try:
ip_address = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(ip_address)
_validate_resolved_hostname(hostname)
except _RestrictedHostError as exc:
raise ValueError(f"{url!r} is not a valid URL. {exc}")


class _RestrictedHostError(Exception):
"""Internal exception raised when a hostname resolves to a restricted address."""


def _validate_resolved_hostname(hostname: str) -> list[str]:
"""Resolve `hostname` and validate every returned address.

Returns the list of resolved IPs (as strings) in resolution order. Raises
`_RestrictedHostError` if any resolved address is private, or if the
hostname cannot be resolved.

Using `getaddrinfo` (rather than `gethostbyname`, which returns only the
first A record) closes an SSRF bypass where a hostname publishes both a
public and a private A/AAAA record: every resolved address is checked.
"""
# IP literal: validate directly.
try:
ip = ipaddress.ip_address(hostname)
except ValueError:
pass
else:
if ip.is_private:
raise _RestrictedHostError(f"It resolves to the private address {ip}.")
return [str(ip)]

try:
addrinfos = socket.getaddrinfo(hostname, None, type=socket.SOCK_STREAM)
except socket.gaierror:
raise _RestrictedHostError("It could not be resolved.")

resolved: list[str] = []
for addrinfo in addrinfos:
sockaddr = addrinfo[4]
ip_str = sockaddr[0]
# Strip IPv6 zone identifier if present (e.g. "fe80::1%eth0")
ip_str = ip_str.split("%", 1)[0]
try:
ip = ipaddress.ip_address(hostname)
ip = ipaddress.ip_address(ip_str)
except ValueError:
raise ValueError(f"{url!r} is not a valid URL. It could not be resolved.")
continue
if ip.is_private:
raise _RestrictedHostError(f"It resolves to the private address {ip}.")
if ip_str not in resolved:
resolved.append(ip_str)

if ip.is_private:
raise ValueError(
f"{url!r} is not a valid URL. It resolves to the private address {ip}."
if not resolved:
raise _RestrictedHostError("It could not be resolved.")

return resolved


class _SSRFProtectedAsyncBackend(httpcore.AsyncNetworkBackend):
    """An `httpcore.AsyncNetworkBackend` that validates resolved addresses.

    Wraps an existing backend and, on each `connect_tcp` call, resolves the
    hostname itself, rejects any resolved address that is private, and then
    connects to the validated IP directly (rather than the hostname) so that
    the underlying backend cannot re-resolve to a different address.

    TLS SNI / certificate validation is unaffected because httpcore passes the
    original hostname to `start_tls` via `server_hostname`, independently of
    the host used for the TCP connection.
    """

    def __init__(self, wrapped: httpcore.AsyncNetworkBackend) -> None:
        self._wrapped = wrapped

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: Optional[Iterable[Any]] = None,
    ) -> httpcore.AsyncNetworkStream:
        # Resolve in a worker thread so the event loop is not blocked during
        # DNS lookups (which can be slow or intermittently failing).
        validated_ips = await anyio.to_thread.run_sync(
            _resolve_and_validate_for_connect, host
        )
        last_exc: Optional[BaseException] = None
        # Share a single timeout budget across all per-IP attempts so total
        # connect time stays bounded by the caller's timeout rather than
        # scaling with the number of resolved addresses.
        deadline = time.monotonic() + timeout if timeout is not None else None
        for ip in validated_ips:
            remaining = _remaining_timeout(deadline)
            if remaining == 0.0:
                break
            try:
                return await self._wrapped.connect_tcp(
                    ip,
                    port,
                    timeout=remaining,
                    local_address=local_address,
                    socket_options=socket_options,
                )
            except (httpcore.ConnectError, httpcore.ConnectTimeout, OSError) as exc:
                last_exc = exc
        if last_exc is None:
            # Budget exhausted before any attempt could be made; raise a
            # proper timeout instead of relying on an `assert` (which is
            # stripped under `python -O`).
            raise httpcore.ConnectTimeout(f"Timed out connecting to {host!r}")
        raise last_exc

    async def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: Optional[Iterable[Any]] = None,
    ) -> httpcore.AsyncNetworkStream:
        # Unix sockets are local by definition; delegate unchanged.
        return await self._wrapped.connect_unix_socket(
            path, timeout=timeout, socket_options=socket_options
        )

    async def sleep(self, seconds: float) -> None:
        await self._wrapped.sleep(seconds)


class _SSRFProtectedSyncBackend(httpcore.NetworkBackend):
    """Synchronous counterpart of `_SSRFProtectedAsyncBackend`.

    Resolves and validates every address for `host` before connecting, and
    connects to the validated IP literal so the wrapped backend cannot
    re-resolve the hostname (closing the DNS-rebinding TOCTOU window).
    """

    def __init__(self, wrapped: httpcore.NetworkBackend) -> None:
        self._wrapped = wrapped

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: Optional[Iterable[Any]] = None,
    ) -> httpcore.NetworkStream:
        validated_ips = _resolve_and_validate_for_connect(host)
        last_exc: Optional[BaseException] = None
        # One shared timeout budget across all per-IP retries keeps the total
        # connect time bounded by the caller's timeout.
        deadline = time.monotonic() + timeout if timeout is not None else None
        for ip in validated_ips:
            remaining = _remaining_timeout(deadline)
            if remaining == 0.0:
                break
            try:
                return self._wrapped.connect_tcp(
                    ip,
                    port,
                    timeout=remaining,
                    local_address=local_address,
                    socket_options=socket_options,
                )
            except (httpcore.ConnectError, httpcore.ConnectTimeout, OSError) as exc:
                last_exc = exc
        if last_exc is None:
            # Budget exhausted before any attempt could be made; raise a
            # proper timeout instead of relying on an `assert` (which is
            # stripped under `python -O`).
            raise httpcore.ConnectTimeout(f"Timed out connecting to {host!r}")
        raise last_exc

    def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: Optional[Iterable[Any]] = None,
    ) -> httpcore.NetworkStream:
        # Unix sockets are local by definition; delegate unchanged.
        return self._wrapped.connect_unix_socket(
            path, timeout=timeout, socket_options=socket_options
        )

    def sleep(self, seconds: float) -> None:
        self._wrapped.sleep(seconds)


def _remaining_timeout(deadline: Optional[float]) -> Optional[float]:
"""Return the time left until `deadline`, or `None` if no deadline is set.

Clamps to `0.0` when the deadline has already passed. Callers use this to
share a single connect-timeout budget across multiple per-IP retries so
that connect time stays bounded (roughly) by the caller's timeout rather
than scaling with the number of resolved addresses.
"""
if deadline is None:
return None
return max(0.0, deadline - time.monotonic())


def _resolve_and_validate_for_connect(host: str) -> list[str]:
    """Resolve `host` and return every IP that is safe to connect to.

    Each returned address has already passed the private-address checks, so
    callers may try them in order and retry on connect failures — keeping
    dual-stack hostnames usable in single-stack environments.

    Raises `httpcore.ConnectError` when the host cannot be resolved or any
    resolved address is private. Because the returned values are IP literals,
    the underlying network backend performs no further DNS resolution, which
    eliminates the DNS-rebinding TOCTOU window.
    """
    try:
        safe_ips = _validate_resolved_hostname(host)
    except _RestrictedHostError as exc:
        raise httpcore.ConnectError(f"Refusing to connect to {host!r}: {exc}") from None
    return safe_ips


class SSRFProtectedAsyncHTTPTransport(httpx.AsyncHTTPTransport):
    """An `httpx.AsyncHTTPTransport` that guards against DNS rebinding SSRF.

    Behaves identically to `httpx.AsyncHTTPTransport` except that every
    request's hostname is resolved, each resolved address is checked against
    the private-address blocklist, and the connection is made to the specific
    validated IP. This closes the TOCTOU window between a pre-flight
    `validate_restricted_url` check and the actual HTTP connection.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Swap the pool's network backend for a wrapper that re-validates at
        # connection time and connects to the pre-resolved address.
        original_backend = self._pool._network_backend
        self._pool._network_backend = _SSRFProtectedAsyncBackend(original_backend)


class SSRFProtectedHTTPTransport(httpx.HTTPTransport):
    """Synchronous counterpart of `SSRFProtectedAsyncHTTPTransport`."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Swap the pool's network backend for a wrapper that re-validates at
        # connection time and connects to the pre-resolved address.
        original_backend = self._pool._network_backend
        self._pool._network_backend = _SSRFProtectedSyncBackend(original_backend)


Expand Down
Loading
Loading