Merge pull request #120 from ariel-frischer/cost-display
Add Cost Display for Default ReAct Agent
This commit is contained in:
commit
b4b0fdd686
|
|
@ -542,7 +542,7 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
agent_type = get_agent_type(agent)
|
agent_type = get_agent_type(agent)
|
||||||
print_agent_output(chunk, agent_type)
|
print_agent_output(chunk, agent_type, cost_cb=cb)
|
||||||
|
|
||||||
if is_completed() or should_exit():
|
if is_completed() or should_exit():
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
|
|
|
||||||
|
|
@ -5,13 +5,23 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||||
|
|
||||||
# Import shared console instance
|
# Import shared console instance
|
||||||
from .formatting import console
|
from .formatting import console
|
||||||
|
|
||||||
|
|
||||||
|
def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]:
|
||||||
|
"""Generate a subtitle with cost information if a callback is provided."""
|
||||||
|
if cost_cb:
|
||||||
|
return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def print_agent_output(
|
def print_agent_output(
|
||||||
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
chunk: Dict[str, Any],
|
||||||
|
agent_type: Literal["CiaynAgent", "React"],
|
||||||
|
cost_cb: Optional[AnthropicCallbackHandler] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Print only the agent's message content, not tool calls.
|
"""Print only the agent's message content, not tool calls.
|
||||||
|
|
||||||
|
|
@ -27,22 +37,40 @@ def print_agent_output(
|
||||||
if isinstance(msg.content, list):
|
if isinstance(msg.content, list):
|
||||||
for content in msg.content:
|
for content in msg.content:
|
||||||
if content["type"] == "text" and content["text"].strip():
|
if content["type"] == "text" and content["text"].strip():
|
||||||
|
subtitle = get_cost_subtitle(cost_cb)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(content["text"]), title="🤖 Assistant")
|
Panel(
|
||||||
|
Markdown(content["text"]),
|
||||||
|
title="🤖 Assistant",
|
||||||
|
subtitle=subtitle,
|
||||||
|
subtitle_align="right",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if msg.content.strip():
|
if msg.content.strip():
|
||||||
|
subtitle = get_cost_subtitle(cost_cb)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(msg.content.strip()), title="🤖 Assistant")
|
Panel(
|
||||||
|
Markdown(msg.content.strip()),
|
||||||
|
title="🤖 Assistant",
|
||||||
|
subtitle=subtitle,
|
||||||
|
subtitle_align="right",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elif "tools" in chunk and "messages" in chunk["tools"]:
|
elif "tools" in chunk and "messages" in chunk["tools"]:
|
||||||
for msg in chunk["tools"]["messages"]:
|
for msg in chunk["tools"]["messages"]:
|
||||||
if msg.status == "error" and msg.content:
|
if msg.status == "error" and msg.content:
|
||||||
err_msg = msg.content.strip()
|
err_msg = msg.content.strip()
|
||||||
|
subtitle = get_cost_subtitle(cost_cb)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(err_msg),
|
Markdown(err_msg),
|
||||||
title="❌ Tool Error",
|
title="❌ Tool Error",
|
||||||
|
subtitle=subtitle,
|
||||||
|
subtitle_align="right",
|
||||||
border_style="red bold",
|
border_style="red bold",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -397,7 +397,7 @@ def test_run_agent_stream(monkeypatch, mock_config_repository):
|
||||||
call_flag = {"called": False}
|
call_flag = {"called": False}
|
||||||
|
|
||||||
def fake_print_agent_output(
|
def fake_print_agent_output(
|
||||||
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"], cost_cb=None
|
||||||
):
|
):
|
||||||
call_flag["called"] = True
|
call_flag["called"] = True
|
||||||
|
|
||||||
|
|
@ -725,4 +725,4 @@ def test_handle_api_error_resource_exhausted():
|
||||||
|
|
||||||
# ResourceExhausted exception should be handled without raising
|
# ResourceExhausted exception should be handled without raising
|
||||||
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
|
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
|
||||||
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue