Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/middleware/test_proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,47 @@ async def test_proxy_headers_empty_x_forwarded_for() -> None:
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "https://127.0.0.1:123"


@pytest.mark.anyio
@pytest.mark.parametrize(
("trusted_hosts", "forwarded_for", "expected"),
[
# IPv4 with port
("127.0.0.1", "1.2.3.4:1024", "https://1.2.3.4:1024"),
# IPv4 without port (existing behavior)
("127.0.0.1", "1.2.3.4", "https://1.2.3.4:0"),
# Bracketed IPv6 with port
("127.0.0.1", "[::1]:8080", "https://::1:8080"),
# Bare IPv6 without port
("127.0.0.1", "::1", "https://::1:0"),
# Trusted proxy with port in XFF should still be recognized
("127.0.0.1, 10.0.0.1", "1.2.3.4:5678, 10.0.0.1:9999", "https://1.2.3.4:5678"),
],
)
async def test_proxy_headers_xff_with_port(trusted_hosts: str | list[str], forwarded_for: str, expected: str) -> None:
async with make_httpx_client(trusted_hosts) as client:
headers = {X_FORWARDED_FOR: forwarded_for, X_FORWARDED_PROTO: "https"}
response = await client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == expected


@pytest.mark.parametrize(
("host_str", "expected_host", "expected_port"),
[
("1.2.3.4", "1.2.3.4", 0),
("1.2.3.4:8080", "1.2.3.4", 8080),
("[::1]:8080", "::1", 8080),
("[::1]", "::1", 0),
("::1", "::1", 0),
("2001:db8::1", "2001:db8::1", 0),
("1.2.3.4:notaport", "1.2.3.4", 0),
],
)
def test_parse_host_and_port(host_str: str, expected_host: str, expected_port: int) -> None:
from uvicorn.middleware.proxy_headers import _parse_host_and_port

host, port = _parse_host_and_port(host_str)
assert host == expected_host
assert port == expected_port
62 changes: 55 additions & 7 deletions uvicorn/middleware/proxy_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,15 @@ async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGIS

if b"x-forwarded-for" in headers:
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)
raw_host = self.trusted_hosts.get_trusted_client_host(x_forwarded_for)

if host:
if raw_host:
# If the x-forwarded-for header is empty then host is an empty string.
# Only set the client if we actually got something usable.
# See: https://github.com/Kludex/uvicorn/issues/1068

# We've lost the connecting client's port information by now,
# so only include the host.
port = 0
# Parse port from X-Forwarded-For entry if present (e.g. "1.2.3.4:8080").
host, port = _parse_host_and_port(raw_host)
scope["client"] = (host, port)

return await self.app(scope, receive, send)
Expand All @@ -64,6 +63,41 @@ def _parse_raw_hosts(value: str) -> list[str]:
return [item.strip() for item in value.split(",")]


def _parse_host_and_port(host: str) -> tuple[str, int]:
"""Parse a host string that may include a port number.

Handles IPv4 (``1.2.3.4:8080``), bracketed IPv6 (``[::1]:8080``),
and bare IPv6 (``::1``) addresses.

Returns:
A ``(host, port)`` tuple. *port* is ``0`` when no port is present.
"""
if host.startswith("["):
# Bracketed IPv6, e.g. [::1]:8080 or [::1]
bracket_end = host.find("]")
if bracket_end == -1:
return host, 0
ip_part = host[1:bracket_end]
rest = host[bracket_end + 1 :]
if rest.startswith(":"):
try:
return ip_part, int(rest[1:])
except ValueError:
return ip_part, 0
return ip_part, 0

# Check for IPv4:port — only split on the *last* colon if there is
# exactly one colon (bare IPv6 addresses contain multiple colons).
if host.count(":") == 1:
ip_part, _, port_str = host.rpartition(":")
try:
return ip_part, int(port_str)
except ValueError:
return ip_part, 0

return host, 0


class _TrustedHosts:
"""Container for trusted hosts and networks"""

Expand Down Expand Up @@ -113,14 +147,28 @@ def __contains__(self, host: str | None) -> bool:
if not host:
return False

# First try the raw value as an IP address
try:
ip = ipaddress.ip_address(host)
if ip in self.trusted_hosts:
return True
return any(ip in net for net in self.trusted_networks)

except ValueError:
return host in self.trusted_literals
pass

# Strip port and retry as IP (e.g. "1.2.3.4:8080" → "1.2.3.4")
host_without_port, _ = _parse_host_and_port(host)
if host_without_port != host:
try:
ip = ipaddress.ip_address(host_without_port)
if ip in self.trusted_hosts:
return True
return any(ip in net for net in self.trusted_networks)
except ValueError:
pass

# Fall back to literal matching (unix sockets, etc.)
return host in self.trusted_literals

def get_trusted_client_host(self, x_forwarded_for: str) -> str:
"""Extract the client host from x_forwarded_for header
Expand Down
Loading