diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index dec0883..e325745 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -542,7 +542,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]): logger.debug("Agent output: %s", chunk) check_interrupt() agent_type = get_agent_type(agent) - print_agent_output(chunk, agent_type) + print_agent_output(chunk, agent_type, cost_cb=cb) if is_completed() or should_exit(): reset_completion_flags() diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index dfba0a8..8a45fec 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -5,13 +5,23 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.exceptions import ToolExecutionError +from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler # Import shared console instance from .formatting import console +def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]: + """Generate a subtitle with cost information if a callback is provided.""" + if cost_cb: + return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}" + return None + + def print_agent_output( - chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"] + chunk: Dict[str, Any], + agent_type: Literal["CiaynAgent", "React"], + cost_cb: Optional[AnthropicCallbackHandler] = None, ) -> None: """Print only the agent's message content, not tool calls. @@ -27,22 +37,40 @@ def print_agent_output( if isinstance(msg.content, list): for content in msg.content: if content["type"] == "text" and content["text"].strip(): + subtitle = get_cost_subtitle(cost_cb) + console.print( - Panel(Markdown(content["text"]), title="🤖 Assistant") + Panel( + Markdown(content["text"]), + title="🤖 Assistant", + subtitle=subtitle, + subtitle_align="right", + ) ) else: if msg.content.strip(): + subtitle = get_cost_subtitle(cost_cb) + console.print( - Panel(Markdown(msg.content.strip()), title="🤖 Assistant") + Panel( + Markdown(msg.content.strip()), + title="🤖 Assistant", + subtitle=subtitle, + subtitle_align="right", + ) ) elif "tools" in chunk and "messages" in chunk["tools"]: for msg in chunk["tools"]["messages"]: if msg.status == "error" and msg.content: err_msg = msg.content.strip() + subtitle = get_cost_subtitle(cost_cb) + console.print( Panel( Markdown(err_msg), title="❌ Tool Error", + subtitle=subtitle, + subtitle_align="right", border_style="red bold", ) ) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 2361518..5292317 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -397,7 +397,7 @@ def test_run_agent_stream(monkeypatch, mock_config_repository): call_flag = {"called": False} def fake_print_agent_output( - chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"] + chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"], cost_cb=None ): call_flag["called"] = True @@ -725,4 +725,4 @@ def test_handle_api_error_resource_exhausted(): # ResourceExhausted exception should be handled without raising resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).") - _handle_api_error(resource_exhausted_error, 0, 5, 1) \ No newline at end of file + _handle_api_error(resource_exhausted_error, 0, 5, 1)