Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,141 changes: 1,141 additions & 0 deletions DEVELOPMENT.md

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions src/class_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import traceback
import uuid
import inspect
from contextlib import redirect_stdout, redirect_stderr
from datetime import datetime
from typing import Dict, Any, Tuple
Expand All @@ -18,11 +19,11 @@ def __init__(self):
self.class_instances: Dict[str, Any] = {}
self.instance_metadata: Dict[str, Dict[str, Any]] = {}

def execute(self, request: FunctionRequest) -> FunctionResponse:
async def execute(self, request: FunctionRequest) -> FunctionResponse:
"""Execute class method."""
return self.execute_class_method(request)
return await self.execute_class_method(request)

def execute_class_method(self, request: FunctionRequest) -> FunctionResponse:
async def execute_class_method(self, request: FunctionRequest) -> FunctionResponse:
"""
Execute a class method with instance management.
"""
Expand Down Expand Up @@ -55,8 +56,13 @@ def execute_class_method(self, request: FunctionRequest) -> FunctionResponse:
args = SerializationUtils.deserialize_args(request.args)
kwargs = SerializationUtils.deserialize_kwargs(request.kwargs)

# Execute the method
result = method(*args, **kwargs)
# Execute the method (handle both sync and async)
if inspect.iscoroutinefunction(method):
# Async method - await directly
result = await method(*args, **kwargs)
else:
# Sync method - call directly
result = method(*args, **kwargs)

# Update instance metadata
self._update_instance_metadata(instance_id)
Expand Down
12 changes: 9 additions & 3 deletions src/function_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import logging
import traceback
import inspect
from contextlib import redirect_stdout, redirect_stderr
from typing import Dict, Any

Expand All @@ -11,7 +12,7 @@
class FunctionExecutor:
"""Handles execution of individual functions with output capture."""

def execute(self, request: FunctionRequest) -> FunctionResponse:
async def execute(self, request: FunctionRequest) -> FunctionResponse:
"""
Execute a function with full output capture.

Expand Down Expand Up @@ -50,8 +51,13 @@ def execute(self, request: FunctionRequest) -> FunctionResponse:
args = SerializationUtils.deserialize_args(request.args)
kwargs = SerializationUtils.deserialize_kwargs(request.kwargs)

# Execute the function
result = func(*args, **kwargs)
# Execute the function (handle both sync and async)
if inspect.iscoroutinefunction(func):
# Async function - await directly
result = await func(*args, **kwargs)
else:
# Sync function - call directly
result = func(*args, **kwargs)

except Exception as e:
# Combine output streams
Expand Down
4 changes: 2 additions & 2 deletions src/remote_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse:

# Execute the function/class
if execution_type == "class":
result = self.class_executor.execute_class_method(request)
result = await self.class_executor.execute_class_method(request)
else:
result = self.function_executor.execute(request)
result = await self.function_executor.execute(request)

# Add all captured system logs to the result
system_logs = get_streamed_logs(clear_buffer=True)
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/test_handler_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def bad_function(:
assert "SyntaxError" in result["error"] or "invalid syntax" in result["error"]

@pytest.mark.integration
def test_remote_executor_direct_execution(self):
@pytest.mark.asyncio
async def test_remote_executor_direct_execution(self):
"""Test RemoteExecutor direct method calls."""
executor = RemoteExecutor()

Expand All @@ -370,7 +371,7 @@ def direct_test():
kwargs={},
)

result = executor.function_executor.execute(request)
result = await executor.function_executor.execute(request)

assert result.success is True
deserialized_result = cloudpickle.loads(base64.b64decode(result.result))
Expand Down
Loading
Loading