Skip to content

Commit c57e178

Browse files
committed
Restore the trio implementation.
This reverts commit 740c8d3.
1 parent c107b98 commit c57e178

File tree

12 files changed

+3509
-0
lines changed

12 files changed

+3509
-0
lines changed

src/websockets/trio/__init__.py

Whitespace-only changes.

src/websockets/trio/connection.py

Lines changed: 1124 additions & 0 deletions
Large diffs are not rendered by default.

src/websockets/trio/messages.py

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
from __future__ import annotations
2+
3+
import codecs
4+
import math
5+
from collections.abc import AsyncIterator
6+
from typing import Any, Callable, Literal, overload
7+
8+
import trio
9+
10+
from ..exceptions import ConcurrencyError
11+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
12+
from ..typing import Data
13+
14+
15+
__all__ = ["Assembler"]
16+
17+
# Incremental UTF-8 decoder class. An incremental decoder is needed to decode
# fragmented text messages frame by frame, because a multi-byte character may
# be split across two frames.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
18+
19+
20+
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames. The stream of frames must
    respect the protocol; if it doesn't, the behavior is undefined.

    Args:
        high: High water mark of the buffer of frames. When only ``low`` is
            given, it defaults to ``4 * low``. When both ``high`` and ``low``
            are :obj:`None`, flow control is disabled.
        low: Low water mark of the buffer of frames. When only ``high`` is
            given, it defaults to ``high // 4``.
        pause: Called when the buffer of frames goes above the high water mark;
            should pause reading from the network.
        resume: Called when the buffer of frames goes below the low water mark;
            should resume reading from the network.

    Raises:
        ValueError: If ``low`` is negative or if ``high`` is lower than ``low``.

    """

    def __init__(
        self,
        high: int | None = None,
        low: int | None = None,
        pause: Callable[[], Any] = lambda: None,
        resume: Callable[[], Any] = lambda: None,
    ) -> None:
        # Queue of incoming frames.
        self.send_frames: trio.MemorySendChannel[Frame]
        self.recv_frames: trio.MemoryReceiveChannel[Frame]
        self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf)

        # We cannot put a hard limit on the size of the queue because a single
        # call to Protocol.data_received() could produce thousands of frames,
        # which must be buffered. Instead, we pause reading when the buffer goes
        # above the high limit and we resume when it goes under the low limit.
        if high is not None and low is None:
            low = high // 4
        if high is None and low is not None:
            high = low * 4
        if high is not None and low is not None:
            if low < 0:
                raise ValueError("low must be positive or equal to zero")
            if high < low:
                raise ValueError("high must be greater than or equal to low")
        self.high, self.low = high, low
        self.pause = pause
        self.resume = resume
        self.paused = False

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # This flag marks the end of the connection.
        self.closed = False

    @overload
    async def get(self, decode: Literal[True]) -> str: ...

    @overload
    async def get(self, decode: Literal[False]) -> bytes: ...

    @overload
    async def get(self, decode: bool | None = None) -> Data: ...

    async def get(self, decode: bool | None = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Args:
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If a text frame contains invalid UTF-8.
            ConcurrencyError: If two coroutines run :meth:`get` or
                :meth:`get_iter` concurrently.

        """
        if self.get_in_progress:
            raise ConcurrencyError("get() or get_iter() is already running")
        self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get() fetches a complete message or is canceled.

        try:
            # Fetch the first frame.
            try:
                frame = await self.recv_frames.receive()
            except trio.EndOfChannel:
                # Hide the internal channel exception from tracebacks.
                raise EOFError("stream of frames ended") from None
            self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            frames = [frame]

            # Fetch subsequent frames for fragmented messages.
            while not frame.fin:
                try:
                    frame = await self.recv_frames.receive()
                except trio.Cancelled:
                    # Put frames already received back into the queue
                    # so that future calls to get() can return them.
                    # Bypass the statistics() method for performance.
                    state = self.send_frames._state
                    assert not state.receive_tasks, "no task should receive"
                    assert not state.data, "queue should be empty"
                    for frame in frames:
                        self.send_frames.send_nowait(frame)
                    raise
                except trio.EndOfChannel:
                    # Hide the internal channel exception from tracebacks.
                    raise EOFError("stream of frames ended") from None
                self.maybe_resume()
                assert frame.opcode is OP_CONT
                frames.append(frame)

        finally:
            self.get_in_progress = False

        # This converts frame.data to bytes when it's a bytearray.
        data = b"".join(frame.data for frame in frames)
        if decode:
            return data.decode()
        else:
            return data

    @overload
    def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...

    @overload
    def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...

    @overload
    def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

    async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`ConcurrencyError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Args:
            decode: :obj:`False` disables UTF-8 decoding of text frames and
                returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
                binary frames and returns :class:`str`.

        Raises:
            EOFError: If the stream of frames has ended.
            UnicodeDecodeError: If a text frame contains invalid UTF-8.
            ConcurrencyError: If two coroutines run :meth:`get` or
                :meth:`get_iter` concurrently.

        """
        if self.get_in_progress:
            raise ConcurrencyError("get() or get_iter() is already running")
        self.get_in_progress = True

        # Locking with get_in_progress prevents concurrent execution
        # until get_iter() fetches a complete message, is canceled while
        # waiting for the first frame, or reaches the end of the stream.

        # If get_iter() raises another exception e.g. in decoder.decode(),
        # get_in_progress remains set and the connection becomes unusable.

        # Yield the first frame.
        try:
            frame = await self.recv_frames.receive()
        except trio.Cancelled:
            self.get_in_progress = False
            raise
        except trio.EndOfChannel:
            # No frame was consumed. Release the lock so that future calls
            # raise EOFError, like get() does, rather than ConcurrencyError.
            self.get_in_progress = False
            raise EOFError("stream of frames ended") from None
        self.maybe_resume()
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
        if decode is None:
            decode = frame.opcode is OP_TEXT
        if decode:
            decoder = UTF8Decoder()
            yield decoder.decode(frame.data, frame.fin)
        else:
            # Convert to bytes when frame.data is a bytearray.
            yield bytes(frame.data)

        # Yield subsequent frames for fragmented messages.
        while not frame.fin:
            # We cannot handle trio.Cancelled because we don't buffer
            # previous fragments — we're streaming them. Canceling get_iter()
            # here will leave the assembler in a stuck state. Future calls to
            # get() or get_iter() will raise ConcurrencyError.
            try:
                frame = await self.recv_frames.receive()
            except trio.EndOfChannel:
                # The stream ended in the middle of a message. There is
                # nothing left to read, so release the lock: future calls
                # raise EOFError rather than ConcurrencyError.
                self.get_in_progress = False
                raise EOFError("stream of frames ended") from None
            self.maybe_resume()
            assert frame.opcode is OP_CONT
            if decode:
                yield decoder.decode(frame.data, frame.fin)
            else:
                # Convert to bytes when frame.data is a bytearray.
                yield bytes(frame.data)

        self.get_in_progress = False

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        Raises:
            EOFError: If the stream of frames has ended.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        self.send_frames.send_nowait(frame)
        self.maybe_pause()

    def maybe_pause(self) -> None:
        """Pause the writer if queue is above the high water mark."""
        # Skip if flow control is disabled.
        if self.high is None:
            return

        # Bypass the statistics() method for performance.
        # Check for "> high" to support high = 0.
        if len(self.send_frames._state.data) > self.high and not self.paused:
            self.paused = True
            self.pause()

    def maybe_resume(self) -> None:
        """Resume the writer if queue is below the low water mark."""
        # Skip if flow control is disabled.
        if self.low is None:
            return

        # Bypass the statistics() method for performance.
        # Check for "<= low" to support low = 0.
        if len(self.send_frames._state.data) <= self.low and self.paused:
            self.paused = False
            self.resume()

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get() or get_iter().
        self.send_frames.close()

src/websockets/trio/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import sys
2+
3+
import trio
4+
5+
6+
if sys.version_info[:2] < (3, 11): # pragma: no cover
7+
from exceptiongroup import BaseExceptionGroup
8+
9+
10+
__all__ = ["race_events"]
11+
12+
13+
# Based on https://trio.readthedocs.io/en/stable/reference-core.html#custom-supervisors
14+
15+
16+
async def jockey(event: trio.Event, cancel_scope: trio.CancelScope) -> None:
    """
    Wait for ``event`` to be set, then cancel ``cancel_scope``.

    Args:
        event: The event to wait for.
        cancel_scope: The scope to cancel once ``event`` is set.

    """
    await event.wait()
    # Reached only when the event is set: cancel the scope that was handed in.
    cancel_scope.cancel()
19+
20+
21+
async def race_events(*events: trio.Event) -> None:
    """
    Return as soon as one of the given events is set.

    One ``jockey`` task is started per event inside a nursery. Each jockey
    cancels the nursery's cancel scope when its event is set, which stops
    the remaining jockeys and lets this function return.

    Args:
        *events: The events to wait for.

    Raises:
        ValueError: If no events are provided.

    """
    if len(events) == 0:
        raise ValueError("no events provided")

    try:
        async with trio.open_nursery() as nursery:
            scope = nursery.cancel_scope
            for candidate in events:
                nursery.start_soon(jockey, candidate, scope)
    except BaseExceptionGroup as group:
        # Re-raise the single exception wrapped in the group, if there is
        # only one; anything else is unexpected and reported as a bug.
        try:
            trio._util.raise_single_exception_from_group(group)
        except trio._util.MultipleExceptionError:  # pragma: no cover
            raise AssertionError(
                "race_events should be canceled; please file a bug report"
            ) from group

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
mitmproxy
22
python-socks[asyncio]
3+
trio

tests/trio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)