Skip to content

Commit 75db4cc

Browse files
wip
1 parent 2fbd8ad commit 75db4cc

File tree

3 files changed

+130
-39
lines changed

3 files changed

+130
-39
lines changed

sdk/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"pyyaml<7.0.0,>=6.0.0",
2727
]
2828
name = "beta9"
29-
version = "0.1.243"
29+
version = "0.1.244"
3030
description = ""
3131

3232
[project.scripts]

sdk/src/beta9/abstractions/sandbox.py

Lines changed: 92 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from ..env import is_remote
5555
from ..exceptions import SandboxConnectionError, SandboxFileSystemError, SandboxProcessError
5656
from ..type import GpuType, GpuTypeAlias
57+
from ..utils import retry_on_transient_error
5758

5859

5960
class Sandbox(Pod):
@@ -871,14 +872,18 @@ def _exec(
871872
command = list(args) if not isinstance(args[0], list) else args[0]
872873
shell_command = " ".join(shlex.quote(arg) for arg in command)
873874

874-
response = self.sandbox_instance.stub.sandbox_exec(
875-
PodSandboxExecRequest(
876-
container_id=self.sandbox_instance.container_id,
877-
command=shell_command,
878-
cwd=cwd,
879-
env=env,
875+
def _do_exec():
876+
return self.sandbox_instance.stub.sandbox_exec(
877+
PodSandboxExecRequest(
878+
container_id=self.sandbox_instance.container_id,
879+
command=shell_command,
880+
cwd=cwd,
881+
env=env,
882+
)
880883
)
881-
)
884+
885+
response = retry_on_transient_error(_do_exec)
886+
882887
if not response.ok or response.pid <= 0:
883888
raise SandboxProcessError(response.error_msg)
884889

@@ -1063,7 +1068,11 @@ def _fetch_next_chunk(self):
10631068
Returns:
10641069
str: New output chunk, or empty string if no new output.
10651070
"""
1066-
output = self.fetch_fn()
1071+
try:
1072+
output = retry_on_transient_error(self.fetch_fn, max_retries=2, delay=0.2)
1073+
except Exception:
1074+
# On persistent failure, return empty to avoid blocking
1075+
return ""
10671076

10681077
if output == self._last_output:
10691078
return ""
@@ -1220,9 +1229,15 @@ def status(self) -> Tuple[int, str]:
12201229
time.sleep(1)
12211230
```
12221231
"""
1223-
response = self.sandbox_instance.stub.sandbox_status(
1224-
PodSandboxStatusRequest(container_id=self.sandbox_instance.container_id, pid=self.pid)
1225-
)
1232+
1233+
def _get_status():
1234+
return self.sandbox_instance.stub.sandbox_status(
1235+
PodSandboxStatusRequest(
1236+
container_id=self.sandbox_instance.container_id, pid=self.pid
1237+
)
1238+
)
1239+
1240+
response = retry_on_transient_error(_get_status)
12261241

12271242
if not response.ok:
12281243
raise SandboxProcessError(response.error_msg)
@@ -1244,14 +1259,17 @@ def stdout(self):
12441259
print(f"STDOUT: {stdout_content}")
12451260
```
12461261
"""
1247-
return SandboxProcessStream(
1248-
self,
1249-
lambda: self.sandbox_instance.stub.sandbox_stdout(
1250-
PodSandboxStdoutRequest(
1251-
container_id=self.sandbox_instance.container_id, pid=self.pid
1252-
)
1253-
).stdout,
1254-
)
1262+
# Cache the stream to preserve buffer state across multiple accesses
1263+
if not hasattr(self, "_stdout_stream"):
1264+
self._stdout_stream = SandboxProcessStream(
1265+
self,
1266+
lambda: self.sandbox_instance.stub.sandbox_stdout(
1267+
PodSandboxStdoutRequest(
1268+
container_id=self.sandbox_instance.container_id, pid=self.pid
1269+
)
1270+
).stdout,
1271+
)
1272+
return self._stdout_stream
12551273

12561274
@property
12571275
def stderr(self):
@@ -1268,14 +1286,17 @@ def stderr(self):
12681286
print(f"STDERR: {stderr_content}")
12691287
```
12701288
"""
1271-
return SandboxProcessStream(
1272-
self,
1273-
lambda: self.sandbox_instance.stub.sandbox_stderr(
1274-
PodSandboxStderrRequest(
1275-
container_id=self.sandbox_instance.container_id, pid=self.pid
1276-
)
1277-
).stderr,
1278-
)
1289+
# Cache the stream to preserve buffer state across multiple accesses
1290+
if not hasattr(self, "_stderr_stream"):
1291+
self._stderr_stream = SandboxProcessStream(
1292+
self,
1293+
lambda: self.sandbox_instance.stub.sandbox_stderr(
1294+
PodSandboxStderrRequest(
1295+
container_id=self.sandbox_instance.container_id, pid=self.pid
1296+
)
1297+
).stderr,
1298+
)
1299+
return self._stderr_stream
12791300

12801301
@property
12811302
def logs(self):
@@ -1301,6 +1322,9 @@ def logs(self):
13011322
all_logs = process.logs.read()
13021323
```
13031324
"""
1325+
# Cache to preserve state across multiple accesses
1326+
if hasattr(self, "_logs_stream"):
1327+
return self._logs_stream
13041328

13051329
class CombinedStream:
13061330
def __init__(self, process: "SandboxProcess"):
@@ -1369,7 +1393,8 @@ def read(self):
13691393
stderr_data = self._stderr.read()
13701394
return stdout_data + stderr_data
13711395

1372-
return CombinedStream(self)
1396+
self._logs_stream = CombinedStream(self)
1397+
return self._logs_stream
13731398

13741399
@property
13751400
def aio(self) -> "AsyncSandboxProcess":
@@ -1394,6 +1419,9 @@ def aio(self) -> "AsyncSandboxProcess":
13941419
def __getstate__(self):
13951420
state = self.__dict__.copy()
13961421
state.pop("_aio", None)
1422+
state.pop("_stdout_stream", None)
1423+
state.pop("_stderr_stream", None)
1424+
state.pop("_logs_stream", None)
13971425
return state
13981426

13991427
def __setstate__(self, state):
@@ -2949,19 +2977,28 @@ class AsyncSandboxProcessStream:
29492977
```
29502978
"""
29512979

2980+
_STOP_SENTINEL = object()
2981+
29522982
def __init__(self, sync_stream: "SandboxProcessStream"):
29532983
self._sync = sync_stream
29542984

29552985
def __aiter__(self):
29562986
"""Return an async iterator for reading the stream line by line."""
29572987
return self
29582988

2959-
async def __anext__(self):
2960-
"""Get the next line from the stream asynchronously."""
2989+
def _get_next(self):
2990+
"""Get next item, returning sentinel on StopIteration."""
29612991
try:
2962-
return await asyncio.to_thread(self._sync.__next__)
2992+
return self._sync.__next__()
29632993
except StopIteration:
2994+
return AsyncSandboxProcessStream._STOP_SENTINEL
2995+
2996+
async def __anext__(self):
2997+
"""Get the next line from the stream asynchronously."""
2998+
result = await asyncio.to_thread(self._get_next)
2999+
if result is AsyncSandboxProcessStream._STOP_SENTINEL:
29643000
raise StopAsyncIteration
3001+
return result
29653002

29663003
async def read(self) -> str:
29673004
"""Return whatever output is currently available in the stream."""
@@ -2979,6 +3016,8 @@ class AsyncCombinedStream:
29793016
```
29803017
"""
29813018

3019+
_STOP_SENTINEL = object()
3020+
29823021
def __init__(self, sync_process: "SandboxProcess"):
29833022
self._sync_process = sync_process
29843023
self._sync_logs = None
@@ -2988,15 +3027,21 @@ def _get_sync_logs(self):
29883027
self._sync_logs = self._sync_process.logs
29893028
return self._sync_logs
29903029

3030+
def _get_next(self):
3031+
"""Get next item, returning sentinel on StopIteration."""
3032+
try:
3033+
return self._get_sync_logs().__next__()
3034+
except StopIteration:
3035+
return AsyncCombinedStream._STOP_SENTINEL
3036+
29913037
def __aiter__(self):
29923038
return self
29933039

29943040
async def __anext__(self):
2995-
try:
2996-
sync_logs = self._get_sync_logs()
2997-
return await asyncio.to_thread(sync_logs.__next__)
2998-
except StopIteration:
3041+
result = await asyncio.to_thread(self._get_next)
3042+
if result is AsyncCombinedStream._STOP_SENTINEL:
29993043
raise StopAsyncIteration
3044+
return result
30003045

30013046
async def read(self) -> str:
30023047
"""Read all combined output."""
@@ -3096,7 +3141,10 @@ def stdout(self) -> "AsyncSandboxProcessStream":
30963141
Returns:
30973142
AsyncSandboxProcessStream: An async stream object for reading stdout.
30983143
"""
3099-
return AsyncSandboxProcessStream(self._sync.stdout)
3144+
# Cache to preserve buffer state across multiple accesses
3145+
if not hasattr(self, "_stdout_stream"):
3146+
self._stdout_stream = AsyncSandboxProcessStream(self._sync.stdout)
3147+
return self._stdout_stream
31003148

31013149
@property
31023150
def stderr(self) -> "AsyncSandboxProcessStream":
@@ -3106,7 +3154,10 @@ def stderr(self) -> "AsyncSandboxProcessStream":
31063154
Returns:
31073155
AsyncSandboxProcessStream: An async stream object for reading stderr.
31083156
"""
3109-
return AsyncSandboxProcessStream(self._sync.stderr)
3157+
# Cache to preserve buffer state across multiple accesses
3158+
if not hasattr(self, "_stderr_stream"):
3159+
self._stderr_stream = AsyncSandboxProcessStream(self._sync.stderr)
3160+
return self._stderr_stream
31103161

31113162
@property
31123163
def logs(self) -> "AsyncCombinedStream":
@@ -3116,7 +3167,10 @@ def logs(self) -> "AsyncCombinedStream":
31163167
Returns:
31173168
AsyncCombinedStream: An async stream object that combines stdout and stderr.
31183169
"""
3119-
return AsyncCombinedStream(self._sync)
3170+
# Cache to preserve state across multiple accesses
3171+
if not hasattr(self, "_logs_stream"):
3172+
self._logs_stream = AsyncCombinedStream(self._sync)
3173+
return self._logs_stream
31203174

31213175

31223176
class AsyncSandboxProcessResponse:

sdk/src/beta9/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,48 @@
22
import inspect
33
import os
44
import sys
5+
import time
56
from pathlib import Path
67

78
from . import terminal
89

910

11+
def retry_on_transient_error(fn, max_retries: int = 3, delay: float = 0.5):
    """
    Call ``fn`` and retry it when it fails with a transient connection error.

    An exception counts as transient when its lower-cased message contains one
    of the substrings "connection", "unavailable", "timeout", "reset" or "eof".
    Any other exception propagates immediately without a retry.

    Args:
        fn: Zero-argument callable to invoke.
        max_retries: Total number of attempts, including the first call.
            Values below 1 are treated as 1 so ``fn`` is always called at
            least once. Default is 3.
        delay: Base delay in seconds between attempts. The wait doubles after
            every failed attempt (delay, 2 * delay, 4 * delay, ...).
            Default is 0.5.

    Returns:
        The result of the first successful call to ``fn``.

    Raises:
        Exception: Whatever ``fn`` raised, either right away for a
            non-transient error or after the final attempt for a transient one.
    """
    transient_keywords = ("connection", "unavailable", "timeout", "reset", "eof")

    # Guarantee at least one call; otherwise the loop would be skipped and
    # there would be neither a result to return nor an exception to raise.
    attempts = max(1, max_retries)

    for attempt in range(attempts):
        try:
            return fn()
        except Exception as exc:
            message = str(exc).lower()
            is_transient = any(keyword in message for keyword in transient_keywords)
            is_last_attempt = attempt == attempts - 1
            if not is_transient or is_last_attempt:
                # Bare raise keeps the original traceback intact.
                raise
            # Exponential backoff: delay, 2 * delay, 4 * delay, ...
            time.sleep(delay * (2**attempt))
45+
46+
1047
class TempFile:
1148
"""
1249
A temporary file that is automatically deleted when closed. This class exists

0 commit comments

Comments
 (0)