@@ -97,7 +97,7 @@ async def handshake(
9797 self .request = self .protocol .connect ()
9898 if additional_headers is not None :
9999 self .request .headers .update (additional_headers )
100- if user_agent_header :
100+ if user_agent_header is not None :
101101 self .request .headers .setdefault ("User-Agent" , user_agent_header )
102102 self .protocol .send_request (self .request )
103103
@@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
363363
364364 self .proxy = proxy
365365 self .protocol_factory = protocol_factory
366- self .handshake_args = (
367- additional_headers ,
368- user_agent_header ,
369- )
366+ self .additional_headers = additional_headers
367+ self .user_agent_header = user_agent_header
370368 self .process_exception = process_exception
371369 self .open_timeout = open_timeout
372370 self .logger = logger
@@ -442,6 +440,7 @@ def factory() -> ClientConnection:
442440 transport = await connect_http_proxy (
443441 proxy_parsed ,
444442 ws_uri ,
443+ user_agent_header = self .user_agent_header ,
445444 ** proxy_kwargs ,
446445 )
447446 # Initialize WebSocket connection via the proxy.
@@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection:
541540 for _ in range (MAX_REDIRECTS ):
542541 self .connection = await self .create_connection ()
543542 try :
544- await self .connection .handshake (* self .handshake_args )
543+ await self .connection .handshake (
544+ self .additional_headers ,
545+ self .user_agent_header ,
546+ )
545547 except asyncio .CancelledError :
546548 self .connection .transport .abort ()
547549 raise
@@ -717,10 +719,16 @@ async def connect_socks_proxy(
717719 raise ImportError ("python-socks is required to use a SOCKS proxy" )
718720
719721
720- def prepare_connect_request (proxy : Proxy , ws_uri : WebSocketURI ) -> bytes :
722+ def prepare_connect_request (
723+ proxy : Proxy ,
724+ ws_uri : WebSocketURI ,
725+ user_agent_header : str | None = None ,
726+ ) -> bytes :
721727 host = build_host (ws_uri .host , ws_uri .port , ws_uri .secure , always_include_port = True )
722728 headers = Headers ()
723729 headers ["Host" ] = build_host (ws_uri .host , ws_uri .port , ws_uri .secure )
730+ if user_agent_header is not None :
731+ headers ["User-Agent" ] = user_agent_header
724732 if proxy .username is not None :
725733 assert proxy .password is not None # enforced by parse_proxy()
726734 headers ["Proxy-Authorization" ] = build_authorization_basic (
@@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
731739
732740
733741class HTTPProxyConnection (asyncio .Protocol ):
734- def __init__ (self , ws_uri : WebSocketURI , proxy : Proxy ):
742+ def __init__ (
743+ self ,
744+ ws_uri : WebSocketURI ,
745+ proxy : Proxy ,
746+ user_agent_header : str | None = None ,
747+ ):
735748 self .ws_uri = ws_uri
736749 self .proxy = proxy
750+ self .user_agent_header = user_agent_header
737751
738752 self .reader = StreamReader ()
739753 self .parser = Response .parse (
@@ -765,7 +779,9 @@ def run_parser(self) -> None:
765779 def connection_made (self , transport : asyncio .BaseTransport ) -> None :
766780 transport = cast (asyncio .Transport , transport )
767781 self .transport = transport
768- self .transport .write (prepare_connect_request (self .proxy , self .ws_uri ))
782+ self .transport .write (
783+ prepare_connect_request (self .proxy , self .ws_uri , self .user_agent_header )
784+ )
769785
770786 def data_received (self , data : bytes ) -> None :
771787 self .reader .feed_data (data )
@@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None:
784800async def connect_http_proxy (
785801 proxy : Proxy ,
786802 ws_uri : WebSocketURI ,
803+ user_agent_header : str | None = None ,
787804 ** kwargs : Any ,
788805) -> asyncio .Transport :
789806 transport , protocol = await asyncio .get_running_loop ().create_connection (
790- lambda : HTTPProxyConnection (ws_uri , proxy ),
807+ lambda : HTTPProxyConnection (ws_uri , proxy , user_agent_header ),
791808 proxy .host ,
792809 proxy .port ,
793810 ** kwargs ,
0 commit comments