Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions requests_toolbelt/adapters/host_header_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def send(self, request, **kwargs):
connection_pool_kwargs = self.poolmanager.connection_pool_kw

if host_header:
# host header can include port, but we should not include it in the
# assert_hostname
host_header = host_header.split(':')[0]

connection_pool_kwargs["assert_hostname"] = host_header
elif "assert_hostname" in connection_pool_kwargs:
# an assert_hostname from a previous request may have been left
Expand Down
25 changes: 16 additions & 9 deletions tests/test_host_header_ssl_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
@pytest.fixture
def session():
"""Create a session with our adapter mounted."""
session = requests.Session()
session.mount('https://', hhssl.HostHeaderSSLAdapter())
s = requests.Session()
s.mount('https://', hhssl.HostHeaderSSLAdapter())
return s


# Let's not spam example.org:
@pytest.mark.skip
class TestHostHeaderSSLAdapter(object):
"""Tests for our HostHeaderSNIAdapter."""
Expand All @@ -30,14 +32,19 @@ def test_ssladapter(self, session):
headers={'Host': 'example.com'})
assert r.status_code == 200

def test_stream(self):
self.session.get('https://54.175.219.8/stream/20',
headers={'Host': 'httpbin.org'},
stream=True)
def test_stream(self, session):
session.get('https://54.175.219.8/stream/20',
headers={'Host': 'httpbin.org'},
stream=True)

def test_case_insensitive_header(self):
r = self.session.get('https://93.184.216.34',
headers={'hOSt': 'example.org'})
def test_case_insensitive_header(self, session):
r = session.get('https://93.184.216.34',
headers={'hOSt': 'example.org'})
assert r.status_code == 200

def test_case_header_with_port(self, session):
r = session.get('https://93.184.216.34',
headers={'Host': 'example.org:443'})
assert r.status_code == 200

def test_plain_requests(self):
Expand Down