Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/cai/repl/commands/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rich.panel import Panel

from cai.repl.commands.base import Command, register_command
from cai.repl.ui.tty_input import read_repl_yes_no
from cai.sdk.agents.models.openai_chatcompletions import get_current_active_model
from cai.repl.commands.model import (
get_all_predefined_models,
Expand Down Expand Up @@ -572,11 +573,7 @@ def handle_cleanup():
f"\n[#9aa0a6][CAI] Compact current conversation? [/]"
f"[bold white]({msg_count} messages)[/bold white]"
)
confirm = console.input(
"[#9aa0a6][CAI] Compact conversation? [/][bold #00ff9d](y/N): [/]"
)

if confirm.lower() == "y":
if read_repl_yes_no(console, "Compact conversation?"):
# Pass the detected agent name to _perform_compaction
return self._perform_compaction(None, None, agent_name=agent_name)
else:
Expand Down
10 changes: 10 additions & 0 deletions src/cai/repl/ui/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def _toolbar_with_separator():
return [sep_line]

# Get user input with all features
result = ""
try:
result = prompt(
[("class:prompt", "CAI> ")],
Expand Down Expand Up @@ -244,6 +245,15 @@ def _toolbar_with_separator():
except (AttributeError, OSError):
_REPL_STDIN_EXHAUSTED_PENDING = True
return ""
finally:
# prompt_toolkit restores termios from the snapshot taken after we cleared
# ICRNL above; Rich console.input and plain input() then echo Enter as ^M.
try:
from cai.util.streaming import restore_terminal_state

restore_terminal_state(emit_trailing_newline=False)
except Exception:
pass

# Print bottom separator only when user submitted non-empty input,
# so that empty Enter produces a single separator between prompts.
Expand Down
78 changes: 78 additions & 0 deletions src/cai/repl/ui/tty_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Line-oriented prompts after prompt_toolkit or Rich Live (y/N, confirmations)."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from rich.console import Console

_CAI_GREY = "#9aa0a6"
_CAI_GREEN = "#00ff9d"


def _restore_tty_after_prompt() -> None:
from cai.util.streaming import restore_terminal_state

restore_terminal_state(emit_trailing_newline=False)


def normalize_repl_line(value: str) -> str:
"""Strip carriage returns left over from non-canonical TTY reads."""
return value.replace("\r", "").strip()


def read_repl_line(
console: Console,
prompt: str = "",
*,
markup: bool = True,
) -> str:
"""Read one line using prompt_toolkit (same stack as the CAI> prompt)."""
from prompt_toolkit import prompt as ptk_prompt

_restore_tty_after_prompt()
try:
if prompt and markup:
# Rich markup is for scrollback context only; ptk owns the input line.
console.print(prompt, markup=markup, emoji=False, end="")
raw = ptk_prompt("")
elif prompt:
raw = ptk_prompt(prompt)
else:
raw = ptk_prompt("")
except (EOFError, KeyboardInterrupt):
return ""
finally:
_restore_tty_after_prompt()
return normalize_repl_line(raw)


def read_repl_yes_no(
console: Console,
label: str,
*,
default: bool = False,
markup: bool = True,
) -> bool:
"""Return True when the user answers y/yes (default answer when they press Enter)."""
from prompt_toolkit import prompt as ptk_prompt
from prompt_toolkit.formatted_text import HTML

_restore_tty_after_prompt()
suffix = "Y/n" if default else "y/N"
ptk_prompt_text = HTML(
f'<style fg="{_CAI_GREY}">[CAI]</style> {label} '
f'<style fg="{_CAI_GREEN}"><b>({suffix})</b></style>: '
)
try:
raw = ptk_prompt(ptk_prompt_text)
except (EOFError, KeyboardInterrupt):
return default
finally:
_restore_tty_after_prompt()

answer = normalize_repl_line(raw)
if not answer:
return default
return answer.lower() in ("y", "yes")
40 changes: 29 additions & 11 deletions src/cai/sdk/agents/models/chatcompletions/httpx_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
# HTTP status codes that trigger automatic retry
_RETRYABLE_STATUS = {429, 502, 503, 504, 529}

# User-facing label — never embed gateway URLs in exception text shown in the REPL.
_LLM_SERVER_LABEL = "Alias Robotics® LLM servers"

_LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -197,11 +200,17 @@ async def _stream_gen() -> AsyncIterator:
# Retries exhausted — raise typed error
if resp.status_code == 429:
raise LLMRateLimited(
f"Rate limited (429) after {_MAX_RETRIES} retries from {url}",
(
f"Rate limited (429) after {_MAX_RETRIES} retries "
f"from {_LLM_SERVER_LABEL}"
),
retry_after=_extract_retry_after(resp),
)
raise LLMProviderUnavailable(
f"Server error ({resp.status_code}) after {_MAX_RETRIES} retries from {url}"
(
f"Server error ({resp.status_code}) after "
f"{_MAX_RETRIES} retries from {_LLM_SERVER_LABEL}"
)
)

# HTTP 413: request body exceeds gateway/proxy POST
Expand All @@ -215,7 +224,7 @@ async def _stream_gen() -> AsyncIterator:
pass
_log_failed_completion_response(resp, url)
raise LLMContextOverflow(
f"Request body too large (413) for {url}",
f"Request body too large (413) for {_LLM_SERVER_LABEL}",
details=_build_413_details(url, body),
)

Expand Down Expand Up @@ -312,11 +321,11 @@ async def _stream_gen() -> AsyncIterator:
continue
if e.response.status_code == 429:
raise LLMRateLimited(
f"Rate limited (429) after retries from {url}"
f"Rate limited (429) after retries from {_LLM_SERVER_LABEL}"
) from e
if e.response.status_code in (408, 504):
raise LLMTimeout(
f"Timeout ({e.response.status_code}) from {url}"
f"Timeout ({e.response.status_code}) from {_LLM_SERVER_LABEL}"
) from e
raise

Expand Down Expand Up @@ -368,19 +377,28 @@ async def _stream_gen() -> AsyncIterator:
# Retries exhausted — raise typed error
if resp.status_code == 429:
raise LLMRateLimited(
f"Rate limited (429) after {_MAX_RETRIES} retries from {url}",
(
f"Rate limited (429) after {_MAX_RETRIES} retries "
f"from {_LLM_SERVER_LABEL}"
),
retry_after=_extract_retry_after(resp),
)
if resp.status_code in (408, 504):
raise LLMTimeout(
f"Timeout ({resp.status_code}) after {_MAX_RETRIES} retries from {url}"
(
f"Timeout ({resp.status_code}) after {_MAX_RETRIES} "
f"retries from {_LLM_SERVER_LABEL}"
)
)
raise LLMProviderUnavailable(
f"Server error ({resp.status_code}) after {_MAX_RETRIES} retries from {url}"
(
f"Server error ({resp.status_code}) after "
f"{_MAX_RETRIES} retries from {_LLM_SERVER_LABEL}"
)
)

if resp.status_code in (408,):
raise LLMTimeout(f"Timeout ({resp.status_code}) from {url}")
raise LLMTimeout(f"Timeout ({resp.status_code}) from {_LLM_SERVER_LABEL}")

# HTTP 413: request body exceeds gateway/proxy POST size cap.
# Not in _RETRYABLE_STATUS because resending the same body
Expand All @@ -390,7 +408,7 @@ async def _stream_gen() -> AsyncIterator:
if resp.status_code == 413:
_log_failed_completion_response(resp, url)
raise LLMContextOverflow(
f"Request body too large (413) for {url}",
f"Request body too large (413) for {_LLM_SERVER_LABEL}",
details=_build_413_details(url, body),
)

Expand Down Expand Up @@ -422,4 +440,4 @@ async def _stream_gen() -> AsyncIterator:
# Should not reach here
if last_error:
raise last_error
raise LLMProviderUnavailable(f"All retries exhausted for {url}")
raise LLMProviderUnavailable(f"All retries exhausted for {_LLM_SERVER_LABEL}")
12 changes: 12 additions & 0 deletions src/cai/sdk/agents/models/chatcompletions/message_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,18 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
arguments = "{}"
elif isinstance(arguments, dict):
arguments = json.dumps(arguments)
else:
# Truncated/streamed-then-cut function call args (e.g. "{")
# are not valid JSON and the upstream proxy rejects the whole
# request with HTTP 400, wedging the conversation. Only repair
# half-finished JSON object/array buffers; leave anything else
# (raw concatenated deltas, provider-specific blobs) alone.
stripped = arguments.lstrip() if isinstance(arguments, str) else ""
if stripped and stripped[0] in "{[":
try:
json.loads(arguments)
except (TypeError, ValueError):
arguments = "{}"
tool_calls_param.append(
ChatCompletionMessageToolCallParam(
id=tc.get("id", "")[:40],
Expand Down
17 changes: 17 additions & 0 deletions src/cai/sdk/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2676,6 +2676,23 @@ def next_sequence_number() -> int:
type="response.content_part.done",
)

# Repair partially-streamed function calls before emission so a
# truncated args string like "{" cannot poison conversation history
# and trigger HTTP 400 on the next request. Only rewrite when the
# buffer is clearly a half-finished JSON object/array; leave other
# accumulator contents (raw concatenated deltas, provider quirks)
# untouched so the rest of the pipeline can decide.
for _fc in state.function_calls.values():
_args = _fc.arguments or ""
_stripped = _args.lstrip()
if not _stripped:
_fc.arguments = "{}"
elif _stripped[0] in "{[":
try:
json.loads(_args)
except (TypeError, ValueError):
_fc.arguments = "{}"

# Actually send events for the function calls
for function_call in state.function_calls.values():
# First, a ResponseOutputItemAdded for the function call
Expand Down
26 changes: 26 additions & 0 deletions src/cai/util/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,32 @@ def _reset_controlling_tty_sane() -> None:
)
except Exception:
pass
ensure_cooked_tty()


def ensure_cooked_tty() -> None:
"""Force canonical line input so ``input()`` / Rich ``console.input`` accept Enter.

``stty sane`` alone is not always enough after prompt_toolkit: it may restore a
snapshot taken while ICRNL was cleared, leaving Enter as raw ``\\r`` (shown as ^M).
"""
if not sys.stdin.isatty():
return
try:
import termios

fd = sys.stdin.fileno()
attrs = termios.tcgetattr(fd)
iflag, _, _, lflag, _, cc = attrs
iflag |= termios.ICRNL | termios.BRKINT
iflag &= ~(termios.INLCR | termios.IGNCR)
lflag |= termios.ICANON | termios.ECHO | termios.ISIG
cc[termios.VMIN] = 1
cc[termios.VTIME] = 0
termios.tcsetattr(fd, termios.TCSADRAIN, (iflag, attrs[1], attrs[2], lflag, attrs[4], cc))
termios.tcflush(fd, termios.TCIFLUSH)
except Exception:
pass


def restore_terminal_state(
Expand Down
34 changes: 34 additions & 0 deletions tests/repl/test_multiline_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,40 @@ def test_multiline_is_enabled_in_prompt_config(self):
'Without this, Enter may only insert newlines after long agent turns.'
)

def test_tty_restored_after_prompt(self):
"""Cooked TTY must be restored after prompt() for follow-up console.input prompts.

We clear ICRNL before prompt_toolkit starts; on exit it restores that snapshot,
so /compact and other y/N confirmations see Enter as ^M unless we run
restore_terminal_state again in a finally block.
"""
from cai.repl.ui.prompt import get_user_input
import inspect

source = inspect.getsource(get_user_input)

assert 'finally:' in source, (
'REGRESSION: get_user_input must use finally to restore the TTY after '
'prompt_toolkit exits. Without this, Rich console.input shows ^M on Enter.'
)
assert source.index('finally:') > source.index('prompt('), (
'REGRESSION: TTY restore must run after prompt(), not only before it.'
)
assert 'restore_terminal_state' in source[source.index('finally:'):], (
'REGRESSION: finally block must call restore_terminal_state after prompt().'
)

def test_ensure_cooked_tty_after_stty_sane(self):
"""stty sane must be followed by explicit cooked termios for y/N prompts."""
from cai.util.streaming import _reset_controlling_tty_sane
import inspect

source = inspect.getsource(_reset_controlling_tty_sane)
assert 'ensure_cooked_tty' in source, (
'REGRESSION: _reset_controlling_tty_sane must call ensure_cooked_tty so '
'follow-up console.input accepts Enter instead of echoing ^M.'
)

def test_icrnl_cleared_before_prompt(self):
"""ICRNL/INLCR/IGNCR must be cleared before prompt() to keep Enter as submit.

Expand Down
48 changes: 48 additions & 0 deletions tests/repl/test_tty_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Tests for REPL follow-up prompts (y/N after prompt_toolkit)."""

from __future__ import annotations

from unittest.mock import MagicMock, patch

from cai.repl.ui.tty_input import (
normalize_repl_line,
read_repl_line,
read_repl_yes_no,
)


class TestNormalizeReplLine:
def test_strips_carriage_return_from_enter(self):
assert normalize_repl_line("y\r") == "y"

def test_strips_whitespace(self):
assert normalize_repl_line(" yes \r\n") == "yes"


class TestReadReplYesNo:
@patch("prompt_toolkit.prompt", return_value="y\r")
@patch("cai.repl.ui.tty_input._restore_tty_after_prompt")
def test_yes_with_carriage_return(self, _restore, _prompt):
console = MagicMock()
assert read_repl_yes_no(console, "Continue") is True

@patch("prompt_toolkit.prompt", return_value="")
@patch("cai.repl.ui.tty_input._restore_tty_after_prompt")
def test_empty_defaults_to_no(self, _restore, _prompt):
console = MagicMock()
assert read_repl_yes_no(console, "Continue", default=False) is False

@patch("prompt_toolkit.prompt", return_value="n")
@patch("cai.repl.ui.tty_input._restore_tty_after_prompt")
def test_no_answer(self, _restore, _prompt):
console = MagicMock()
assert read_repl_yes_no(console, "Continue") is False


class TestReadReplLine:
@patch("prompt_toolkit.prompt", return_value="RESET\r")
@patch("cai.repl.ui.tty_input._restore_tty_after_prompt")
def test_restores_tty_around_prompt(self, restore, _prompt):
console = MagicMock()
assert read_repl_line(console, "> ", markup=False) == "RESET"
assert restore.call_count == 2
Loading
Loading