Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ dependencies = ["pytest>=7.0.0"]
[dependency-groups]
dev = [
"coverage[toml] >= 7.9",
"pytest-httpbin >= 2.1.0",
"pytest-randomly >= 3.15.0",
"requests >= 2.32.4",
"starlette >= 0.47.1",
"httpx >= 0.28.1",
"mypy >= 1.20",
"pytest-randomly >= 3.15.0",
"starlette >= 0.47.1",
]

[project.urls]
Expand Down
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import socket
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer

import pytest

Expand All @@ -9,3 +12,36 @@
not hasattr(socket, "AF_UNIX"),
reason="Skip any platform that does not support AF_UNIX",
)


class _SimpleHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
self.send_response(200)
self.send_header("Content-Type", "text/plain")
self.end_headers()
self.wfile.write(b"OK")

def log_message(self, format, *args) -> None:
pass # suppress request logging during tests


@dataclass
class _ServerInfo:
host: str
port: int

@property
def url(self) -> str:
return f"http://{self.host}:{self.port}"


@pytest.fixture(scope="session")
def httpserver() -> _ServerInfo:
"""A lightweight local HTTP server for testing socket connections."""
server = HTTPServer(("127.0.0.1", 0), _SimpleHandler)
host, port = server.server_address
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
yield _ServerInfo(host=host, port=port)
server.shutdown()
34 changes: 15 additions & 19 deletions tests/test_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,36 @@
from .conftest import unix_sockets_only


def test_parametrize_with_socket_enabled_and_allow_hosts(pytester, httpbin):
def test_parametrize_with_socket_enabled_and_allow_hosts(pytester, httpserver):
"""This is a complex test that demonstrates the use of `parametrize`,
`enable_socket` fixture, allow_hosts CLI flag.

TODO: This test makes real http calls. httpbin only provides a single IP.
Is there a better way to express multiple **working** IPs?

From: https://github.com/miketheman/pytest-socket/issues/56
"""
pytester.makepyfile(f"""
import socket
import pytest
import requests
from urllib.request import urlopen


@pytest.mark.parametrize(
"url",
[
"https://google.com",
"https://amazon.com",
"https://www.microsoft.com",
],
"host",
["google.com", "www.amazon.com", "www.microsoft.com"],
)
def test_domain(url, socket_enabled):
requests.get(url)
def test_domain(host, socket_enabled):
# Just verify socket creation and connect aren't blocked
sock = socket.create_connection((host, 443), timeout=5)
sock.close()

def test_localhost_works():
requests.get("{httpbin.url}/")
urlopen("{httpserver.url}/")

def test_remote_not_allowed_fails():
requests.get("http://172.1.1.1/")
urlopen("http://172.1.1.1/")
""")
pytester.makeini(f"""
[pytest]
addopts = --disable-socket --allow-hosts={httpbin.host}
addopts = --disable-socket --allow-hosts={httpserver.host}
""")
result = pytester.runpytest()
result.assert_outcomes(passed=4, failed=1)
Expand All @@ -47,7 +43,7 @@ def test_remote_not_allowed_fails():


@unix_sockets_only
def test_combine_unix_and_allow_hosts(pytester, httpbin):
def test_combine_unix_and_allow_hosts(pytester, httpserver):
"""Test combination of disable, allow-unix and allow-hosts.

From https://github.com/miketheman/pytest-socket/issues/78
Expand All @@ -66,9 +62,9 @@ def test_unix_connect():

def test_inet_connect():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('{httpbin.host}', {httpbin.port}))
sock.connect(('{httpserver.host}', {httpserver.port}))
""")
result = pytester.runpytest(
"--disable-socket", "--allow-unix-socket", f"--allow-hosts={httpbin.host}"
"--disable-socket", "--allow-unix-socket", f"--allow-hosts={httpserver.host}"
)
result.assert_outcomes(passed=2)
6 changes: 3 additions & 3 deletions tests/test_precedence.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ def test_socket_disabled():
assert_socket_blocked(result, passed=1, failed=1)


def test_global_disable_and_allow_host(pytester, httpbin):
def test_global_disable_and_allow_host(pytester, httpserver):
"""Disable socket globally, but allow a specific host"""
pytester.makepyfile(f"""
from urllib.request import urlopen

def test_urlopen():
assert urlopen("{httpbin.url}/")
assert urlopen("{httpserver.url}/")

def test_urlopen_disabled():
assert urlopen("https://google.com/")
""")
result = pytester.runpytest("--disable-socket", f"--allow-hosts={httpbin.host}")
result = pytester.runpytest("--disable-socket", f"--allow-hosts={httpserver.host}")
assert_socket_blocked(result, passed=1, failed=1)
36 changes: 18 additions & 18 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def assert_host_blocked(result, host):


@pytest.fixture
def assert_connect(httpbin, pytester):
def assert_connect(httpserver, pytester):
def assert_socket_connect(should_pass, **kwargs):
# get the name of the calling function
test_name = inspect.stack()[1][3]

mark = ""
host = kwargs.get("host", httpbin.host)
host = kwargs.get("host", httpserver.host)
cli_arg = kwargs.get("cli_arg", None)
code_template = kwargs.get("code_template", connect_code_template)
mark_arg = kwargs.get("mark_arg", None)
Expand All @@ -70,7 +70,7 @@ def assert_socket_connect(should_pass, **kwargs):
elif isinstance(mark_arg, list):
hosts = '","'.join(mark_arg)
mark = f'@pytest.mark.allow_hosts(["{hosts}"])'
code = code_template.format(host, httpbin.port, test_name, mark)
code = code_template.format(host, httpserver.port, test_name, mark)
pytester.makepyfile(code)

if cli_arg:
Expand Down Expand Up @@ -253,62 +253,62 @@ def test_global_restrict_via_config_fail():
assert_host_blocked(result, "127.0.0.1")


def test_global_restrict_via_config_pass(pytester, httpbin):
def test_global_restrict_via_config_pass(pytester, httpserver):
pytester.makepyfile(f"""
import socket

def test_global_restrict_via_config_pass():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))
""")
pytester.makeini(f"""
[pytest]
addopts = --allow-hosts={httpbin.host}
addopts = --allow-hosts={httpserver.host}
""")
result = pytester.runpytest()
result.assert_outcomes(passed=1)


def test_test_isolation(pytester, httpbin):
def test_test_isolation(pytester, httpserver):
pytester.makepyfile(f"""
import pytest
import socket

@pytest.mark.allow_hosts('{httpbin.host}')
@pytest.mark.allow_hosts('{httpserver.host}')
def test_pass():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))

@pytest.mark.allow_hosts('2.2.2.2')
def test_fail():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))

def test_pass_2():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))
""")
result = pytester.runpytest()
result.assert_outcomes(passed=2, failed=1)
assert_host_blocked(result, httpbin.host)
assert_host_blocked(result, httpserver.host)


def test_conflicting_cli_vs_marks(pytester, httpbin):
def test_conflicting_cli_vs_marks(pytester, httpserver):
pytester.makepyfile(f"""
import pytest
import socket

@pytest.mark.allow_hosts('{httpbin.host}')
@pytest.mark.allow_hosts('{httpserver.host}')
def test_pass():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))

@pytest.mark.allow_hosts('2.2.2.2')
def test_fail():
socket.socket().connect(('{httpbin.host}', {httpbin.port}))
socket.socket().connect(('{httpserver.host}', {httpserver.port}))

def test_fail_2():
socket.socket().connect(('2.2.2.2', {httpbin.port}))
socket.socket().connect(('2.2.2.2', {httpserver.port}))
""")
result = pytester.runpytest("--allow-hosts=1.2.3.4")
result.assert_outcomes(passed=1, failed=2)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, httpbin.host)
assert_host_blocked(result, httpserver.host)


def test_normalize_allowed_hosts(getaddrinfo_hosts):
Expand Down
Loading
Loading