|
5 | 5 | import email.utils |
6 | 6 | import http |
7 | 7 | import warnings |
| 8 | +import re |
8 | 9 | from collections.abc import Generator, Sequence |
9 | 10 | from typing import Any, Callable, cast |
10 | 11 |
|
@@ -49,7 +50,7 @@ class ServerProtocol(Protocol): |
49 | 50 | Sans-I/O implementation of a WebSocket server connection. |
50 | 51 |
|
51 | 52 | Args: |
52 | | - origins: Acceptable values of the ``Origin`` header; include |
| 53 | + origins: Acceptable values of the ``Origin`` header, including regular expressions; include |
53 | 54 | :obj:`None` in the list if the lack of an origin is acceptable. |
54 | 55 | This is useful for defending against Cross-Site WebSocket |
55 | 56 | Hijacking attacks. |
@@ -309,10 +310,37 @@ def process_origin(self, headers: Headers) -> Origin | None: |
309 | 310 | if origin is not None: |
310 | 311 | origin = cast(Origin, origin) |
311 | 312 | if self.origins is not None: |
312 | | - if origin not in self.origins: |
| 313 | + valid = False |
| 314 | + for origin_maybe_regex in self.origins: |
| 315 | + if origin_maybe_regex is None: |
| 316 | + continue |
| 317 | + if self.probably_regex(origin_maybe_regex): |
| 318 | + valid = re.match(origin_maybe_regex, origin) |
| 319 | + else: |
| 320 | + valid = origin_maybe_regex == origin |
| 321 | + if valid: |
| 322 | + break |
| 323 | + if not valid: |
313 | 324 | raise InvalidOrigin(origin) |
314 | 325 | return origin |
315 | 326 |
|
| 327 | + @staticmethod |
| 328 | + def probably_regex(maybe_regex: str) -> bool: |
| 329 | + """ |
| 330 | + Determine if the given string is a regex. |
| 331 | +
|
| 332 | + Args: |
| 333 | + maybe_regex: A string that may be a regex. |
| 334 | +
|
| 335 | + Returns: |
| 336 | + True if the string is a regex, False otherwise. |
| 337 | +
|
| 338 | + """ |
| 339 | + common_regex_chars = ['*', '\\', ']', '?', '$', '^', '[', ']', '(', ')'] |
| 340 | + # Use common characters used in regular expressions as a proxy |
| 341 | + # for if this string is in fact a regex. |
| 342 | + return any((c in maybe_regex for c in common_regex_chars)) |
| 343 | + |
316 | 344 | def process_extensions( |
317 | 345 | self, |
318 | 346 | headers: Headers, |
|
0 commit comments