feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models (#118)
style(agent_utils.py): format imports and code for better readability refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
This commit is contained in:
parent
d194868cff
commit
2899b5f848
|
|
@ -10,6 +10,9 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||
|
||||
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
from openai import RateLimitError as OpenAIRateLimitError
|
||||
|
|
@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
|
|||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
||||
from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
|
||||
from ra_aid.prompts.reasoning_assist_prompt import (
|
||||
REASONING_ASSIST_PROMPT_PLANNING,
|
||||
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
||||
REASONING_ASSIST_PROMPT_RESEARCH,
|
||||
)
|
||||
from ra_aid.prompts.research_prompts import (
|
||||
RESEARCH_ONLY_PROMPT,
|
||||
RESEARCH_PROMPT,
|
||||
|
|
@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
|
|||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import (
|
||||
get_key_snippet_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
get_human_input_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
get_research_note_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -332,7 +345,9 @@ def create_agent(
|
|||
if is_anthropic_claude(config):
|
||||
logger.debug("Using create_react_agent to instantiate agent.")
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
else:
|
||||
logger.debug("Using CiaynAgent agent instance")
|
||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
||||
|
|
@ -343,7 +358,9 @@ def create_agent(
|
|||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
|
||||
|
||||
def run_research_agent(
|
||||
|
|
@ -406,7 +423,9 @@ def run_research_agent(
|
|||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
last_human_input = recent_inputs[0].content
|
||||
base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
||||
base_task = (
|
||||
f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access human input repository: {str(e)}")
|
||||
# Continue without appending last human input
|
||||
|
|
@ -416,7 +435,9 @@ def run_research_agent(
|
|||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
related_files = get_related_files()
|
||||
|
||||
try:
|
||||
|
|
@ -445,7 +466,9 @@ def run_research_agent(
|
|||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
elif disable_assistance:
|
||||
|
|
@ -459,7 +482,9 @@ def run_research_agent(
|
|||
|
||||
# Get research note information for reasoning assistance
|
||||
try:
|
||||
research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
|
||||
research_notes = format_research_notes_dict(
|
||||
get_research_note_repository().get_notes_dict()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get research notes: {e}")
|
||||
research_notes = ""
|
||||
|
|
@ -467,7 +492,10 @@ def run_research_agent(
|
|||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
|
|
@ -503,7 +531,13 @@ def run_research_agent(
|
|||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best research approach."
|
||||
),
|
||||
title="📝 Thinking about research strategy...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Invoking expert model for reasoning assist")
|
||||
|
|
@ -517,7 +551,7 @@ def run_research_agent(
|
|||
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
|
|
@ -533,19 +567,27 @@ def run_research_agent(
|
|||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
# Extract thinking content
|
||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
||||
thinking_content = item['thinking']
|
||||
if item.get("type") == "thinking" and "thinking" in item:
|
||||
thinking_content = item["thinking"]
|
||||
logger.debug("Found structured thinking content")
|
||||
# Extract response text
|
||||
elif item.get('type') == 'text' and 'text' in item:
|
||||
response_text = item['text']
|
||||
elif item.get("type") == "text" and "text" in item:
|
||||
response_text = item["text"]
|
||||
logger.debug("Found structured response text")
|
||||
|
||||
# Display thinking content in a separate panel if available
|
||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
||||
if thinking_content and get_config_repository().get(
|
||||
"show_thoughts", False
|
||||
):
|
||||
logger.debug(
|
||||
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||
)
|
||||
console.print(
|
||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(thinking_content),
|
||||
title="💭 Expert Thinking",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Use response_text if available, otherwise fall back to joining
|
||||
|
|
@ -553,9 +595,11 @@ def run_research_agent(
|
|||
content = response_text
|
||||
else:
|
||||
# Fallback: join list items if structured extraction failed
|
||||
logger.debug("No structured response text found, joining list items")
|
||||
logger.debug(
|
||||
"No structured response text found, joining list items"
|
||||
)
|
||||
content = "\n".join(str(item) for item in content)
|
||||
elif (supports_think_tag or supports_thinking):
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content=content,
|
||||
|
|
@ -563,16 +607,22 @@ def run_research_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Expert Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Display the expert guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content),
|
||||
title="Research Strategy Guidance",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
# Use the content as expert guidance
|
||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
||||
expert_guidance = (
|
||||
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
||||
)
|
||||
|
||||
logger.info("Received expert guidance for research")
|
||||
except Exception as e:
|
||||
|
|
@ -643,9 +693,7 @@ def run_research_agent(
|
|||
if agent is not None:
|
||||
logger.debug("Research agent created successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log research completion
|
||||
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
||||
|
|
@ -731,7 +779,9 @@ def run_web_research_agent(
|
|||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
|
@ -771,9 +821,7 @@ def run_web_research_agent(
|
|||
|
||||
logger.debug("Web research agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log web research completion
|
||||
log_work_event(f"Completed web research phase for: {query}")
|
||||
|
|
@ -844,7 +892,9 @@ def run_planning_agent(
|
|||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
|
|
@ -869,7 +919,9 @@ def run_planning_agent(
|
|||
|
||||
# Make sure key_snippets is defined before using it
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
|
@ -898,7 +950,10 @@ def run_planning_agent(
|
|||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
|
|
@ -934,7 +989,13 @@ def run_planning_agent(
|
|||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best way to do this."
|
||||
),
|
||||
title="📝 Thinking about the plan...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Invoking expert model for reasoning assist")
|
||||
|
|
@ -948,7 +1009,7 @@ def run_planning_agent(
|
|||
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
|
|
@ -964,19 +1025,27 @@ def run_planning_agent(
|
|||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
# Extract thinking content
|
||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
||||
thinking_content = item['thinking']
|
||||
if item.get("type") == "thinking" and "thinking" in item:
|
||||
thinking_content = item["thinking"]
|
||||
logger.debug("Found structured thinking content")
|
||||
# Extract response text
|
||||
elif item.get('type') == 'text' and 'text' in item:
|
||||
response_text = item['text']
|
||||
elif item.get("type") == "text" and "text" in item:
|
||||
response_text = item["text"]
|
||||
logger.debug("Found structured response text")
|
||||
|
||||
# Display thinking content in a separate panel if available
|
||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
||||
if thinking_content and get_config_repository().get(
|
||||
"show_thoughts", False
|
||||
):
|
||||
logger.debug(
|
||||
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||
)
|
||||
console.print(
|
||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(thinking_content),
|
||||
title="💭 Expert Thinking",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Use response_text if available, otherwise fall back to joining
|
||||
|
|
@ -984,9 +1053,11 @@ def run_planning_agent(
|
|||
content = response_text
|
||||
else:
|
||||
# Fallback: join list items if structured extraction failed
|
||||
logger.debug("No structured response text found, joining list items")
|
||||
logger.debug(
|
||||
"No structured response text found, joining list items"
|
||||
)
|
||||
content = "\n".join(str(item) for item in content)
|
||||
elif (supports_think_tag or supports_thinking):
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content=content,
|
||||
|
|
@ -994,16 +1065,20 @@ def run_planning_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Expert Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Display the expert guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content), title="Reasoning Guidance", border_style="blue"
|
||||
)
|
||||
)
|
||||
|
||||
# Use the content as expert guidance
|
||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
||||
expert_guidance = (
|
||||
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
||||
)
|
||||
|
||||
logger.info("Received expert guidance for planning")
|
||||
except Exception as e:
|
||||
|
|
@ -1050,7 +1125,9 @@ def run_planning_agent(
|
|||
)
|
||||
|
||||
config_values = get_config_repository().get_all()
|
||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
recursion_limit = get_config_repository().get(
|
||||
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||
)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": recursion_limit,
|
||||
|
|
@ -1060,9 +1137,7 @@ def run_planning_agent(
|
|||
try:
|
||||
logger.debug("Planning agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, planning_prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log planning completion
|
||||
log_work_event(f"Completed planning phase for: {base_task}")
|
||||
|
|
@ -1168,7 +1243,9 @@ def run_task_implementation_agent(
|
|||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
|
|
@ -1186,7 +1263,10 @@ def run_task_implementation_agent(
|
|||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting implementation guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
|
|
@ -1197,7 +1277,9 @@ def run_task_implementation_agent(
|
|||
tool_info = get_tool_info(tool.func)
|
||||
name = tool.func.__name__
|
||||
description = inspect.getdoc(tool.func)
|
||||
tool_metadata.append(f"Tool: {name}\\nDescription: {description}\\n")
|
||||
tool_metadata.append(
|
||||
f"Tool: {name}\\nDescription: {description}\\n"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting tool info for {tool}: {e}")
|
||||
|
||||
|
|
@ -1213,7 +1295,9 @@ def run_task_implementation_agent(
|
|||
working_directory=working_directory,
|
||||
task=task,
|
||||
key_facts=key_facts,
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
key_snippets=format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
),
|
||||
research_notes=formatted_research_notes,
|
||||
related_files="\\n".join(related_files),
|
||||
env_inv=env_inv,
|
||||
|
|
@ -1222,7 +1306,13 @@ def run_task_implementation_agent(
|
|||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best implementation approach."
|
||||
),
|
||||
title="📝 Thinking about implementation...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug("Invoking expert model for implementation reasoning assist")
|
||||
|
|
@ -1236,7 +1326,7 @@ def run_task_implementation_agent(
|
|||
# Process response content
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
|
|
@ -1249,12 +1339,16 @@ def run_task_implementation_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Implementation Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Display the implementation guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content),
|
||||
title="Implementation Guidance",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
# Format the implementation guidance section for the prompt
|
||||
|
|
@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
|
|||
plan=plan,
|
||||
related_files=related_files,
|
||||
key_facts=key_facts,
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
key_snippets=format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
),
|
||||
research_notes=formatted_research_notes,
|
||||
work_log=get_work_log_repository().format_work_log(),
|
||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||
|
|
@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
|
|||
)
|
||||
|
||||
config_values = get_config_repository().get_all()
|
||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
recursion_limit = get_config_repository().get(
|
||||
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||
)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": recursion_limit,
|
||||
|
|
@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
|
|||
try:
|
||||
logger.debug("Implementation agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log task implementation completion
|
||||
log_work_event(f"Completed implementation of task: {task}")
|
||||
|
|
@ -1380,19 +1476,29 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
# 1. Check if this is a ValueError with 429 code or rate limit phrases
|
||||
if isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
|
||||
if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
|
||||
rate_limit_phrases = [
|
||||
"429",
|
||||
"rate limit",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
]
|
||||
if "code" not in error_str and not any(
|
||||
phrase in error_str for phrase in rate_limit_phrases
|
||||
):
|
||||
raise e
|
||||
|
||||
# 2. Check for status_code or http_status attribute equal to 429
|
||||
if hasattr(e, 'status_code') and e.status_code == 429:
|
||||
if hasattr(e, "status_code") and e.status_code == 429:
|
||||
pass # This is a rate limit error, continue with retry logic
|
||||
elif hasattr(e, 'http_status') and e.http_status == 429:
|
||||
elif hasattr(e, "http_status") and e.http_status == 429:
|
||||
pass # This is a rate limit error, continue with retry logic
|
||||
# 3. Check for rate limit phrases in error message
|
||||
elif isinstance(e, Exception) and not isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
|
||||
if not any(
|
||||
phrase in error_str
|
||||
for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
|
||||
) and not ("rate" in error_str and "limit" in error_str):
|
||||
# This doesn't look like a rate limit error, but we'll still retry other API errors
|
||||
pass
|
||||
|
||||
|
|
@ -1468,22 +1574,39 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
|||
human-in-the-loop interruptions using interrupt_after=["tools"].
|
||||
"""
|
||||
config = get_config_repository().get_all()
|
||||
stream_config = config.copy()
|
||||
|
||||
cb = None
|
||||
if is_anthropic_claude(config):
|
||||
model_name = config.get("model", "")
|
||||
full_model_name = model_name
|
||||
cb = AnthropicCallbackHandler(full_model_name)
|
||||
|
||||
if "callbacks" not in stream_config:
|
||||
stream_config["callbacks"] = []
|
||||
stream_config["callbacks"].append(cb)
|
||||
|
||||
while True:
|
||||
# Process each chunk from the agent stream.
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
for chunk in agent.stream({"messages": msg_list}, stream_config):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
|
||||
if is_completed() or should_exit():
|
||||
reset_completion_flags()
|
||||
return True # Exit immediately when finished or signaled to exit.
|
||||
if cb:
|
||||
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||
return True
|
||||
|
||||
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
||||
|
||||
# Prepare state configuration, ensuring 'configurable' is present.
|
||||
state_config = get_config_repository().get_all().copy()
|
||||
if "configurable" not in state_config:
|
||||
logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
|
||||
logger.debug(
|
||||
"Key 'configurable' not found in config; adding it as an empty dict."
|
||||
)
|
||||
state_config["configurable"] = {}
|
||||
logger.debug("Using state_config for agent.get_state(): %s", state_config)
|
||||
|
||||
|
|
@ -1491,21 +1614,27 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
|||
state = agent.get_state(state_config)
|
||||
logger.debug("Agent state retrieved: %s", state)
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
|
||||
logger.error(
|
||||
"Error retrieving agent state with state_config %s: %s", state_config, e
|
||||
)
|
||||
raise
|
||||
|
||||
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
||||
# then resume execution by invoking the agent with no new input.
|
||||
if state.next:
|
||||
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
||||
agent.invoke(None, config)
|
||||
logger.debug(
|
||||
"State indicates continuation (state.next: %s); resuming execution.",
|
||||
state.next,
|
||||
)
|
||||
agent.invoke(None, stream_config)
|
||||
continue
|
||||
else:
|
||||
logger.debug("No continuation indicated in state; exiting stream loop.")
|
||||
break
|
||||
|
||||
if cb:
|
||||
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||
return True
|
||||
|
||||
|
||||
def run_agent_with_retry(
|
||||
agent: RAgents,
|
||||
prompt: str,
|
||||
|
|
@ -1517,7 +1646,9 @@ def run_agent_with_retry(
|
|||
max_retries = 20
|
||||
base_delay = 1
|
||||
test_attempts = 0
|
||||
_max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||
_max_test_retries = get_config_repository().get(
|
||||
"max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
|
||||
)
|
||||
auto_test = get_config_repository().get("auto_test", False)
|
||||
original_prompt = prompt
|
||||
msg_list = [HumanMessage(content=prompt)]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,270 @@
|
|||
"""Custom callback handlers for tracking token usage and costs."""
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
# Define cost per 1K tokens for various models
|
||||
ANTHROPIC_MODEL_COSTS = {
|
||||
# Claude 3.7 Sonnet input
|
||||
"claude-3-7-sonnet-20250219": 0.003,
|
||||
"anthropic/claude-3.7-sonnet": 0.003,
|
||||
"claude-3.7-sonnet": 0.003,
|
||||
# Claude 3.7 Sonnet output
|
||||
"claude-3-7-sonnet-20250219-completion": 0.015,
|
||||
"anthropic/claude-3.7-sonnet-completion": 0.015,
|
||||
"claude-3.7-sonnet-completion": 0.015,
|
||||
# Claude 3 Opus input
|
||||
"claude-3-opus-20240229": 0.015,
|
||||
"anthropic/claude-3-opus": 0.015,
|
||||
"claude-3-opus": 0.015,
|
||||
# Claude 3 Opus output
|
||||
"claude-3-opus-20240229-completion": 0.075,
|
||||
"anthropic/claude-3-opus-completion": 0.075,
|
||||
"claude-3-opus-completion": 0.075,
|
||||
# Claude 3 Sonnet input
|
||||
"claude-3-sonnet-20240229": 0.003,
|
||||
"anthropic/claude-3-sonnet": 0.003,
|
||||
"claude-3-sonnet": 0.003,
|
||||
# Claude 3 Sonnet output
|
||||
"claude-3-sonnet-20240229-completion": 0.015,
|
||||
"anthropic/claude-3-sonnet-completion": 0.015,
|
||||
"claude-3-sonnet-completion": 0.015,
|
||||
# Claude 3 Haiku input
|
||||
"claude-3-haiku-20240307": 0.00025,
|
||||
"anthropic/claude-3-haiku": 0.00025,
|
||||
"claude-3-haiku": 0.00025,
|
||||
# Claude 3 Haiku output
|
||||
"claude-3-haiku-20240307-completion": 0.00125,
|
||||
"anthropic/claude-3-haiku-completion": 0.00125,
|
||||
"claude-3-haiku-completion": 0.00125,
|
||||
# Claude 2 input
|
||||
"claude-2": 0.008,
|
||||
"claude-2.0": 0.008,
|
||||
"claude-2.1": 0.008,
|
||||
# Claude 2 output
|
||||
"claude-2-completion": 0.024,
|
||||
"claude-2.0-completion": 0.024,
|
||||
"claude-2.1-completion": 0.024,
|
||||
# Claude Instant input
|
||||
"claude-instant-1": 0.0016,
|
||||
"claude-instant-1.2": 0.0016,
|
||||
# Claude Instant output
|
||||
"claude-instant-1-completion": 0.0055,
|
||||
"claude-instant-1.2-completion": 0.0055,
|
||||
}
|
||||
|
||||
|
||||
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
|
||||
"""
|
||||
Standardize the model name to a format that can be used for cost calculation.
|
||||
|
||||
Args:
|
||||
model_name: Model name to standardize.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
|
||||
Returns:
|
||||
Standardized model name.
|
||||
"""
|
||||
if not model_name:
|
||||
model_name = "claude-3-sonnet"
|
||||
|
||||
model_name = model_name.lower()
|
||||
|
||||
# Handle OpenRouter prefixes
|
||||
if model_name.startswith("anthropic/"):
|
||||
model_name = model_name[len("anthropic/") :]
|
||||
|
||||
# Add completion suffix if needed
|
||||
if is_completion and not model_name.endswith("-completion"):
|
||||
model_name = model_name + "-completion"
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def get_anthropic_token_cost_for_model(
|
||||
model_name: str, num_tokens: int, is_completion: bool = False
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost in USD for a given model and number of tokens.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
num_tokens: Number of tokens.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
|
||||
Returns:
|
||||
Cost in USD.
|
||||
"""
|
||||
model_name = standardize_model_name(model_name, is_completion)
|
||||
|
||||
if model_name not in ANTHROPIC_MODEL_COSTS:
|
||||
# Default to Claude 3 Sonnet pricing if model not found
|
||||
model_name = (
|
||||
"claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
|
||||
)
|
||||
|
||||
cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
|
||||
total_cost = cost_per_1k * (num_tokens / 1000)
|
||||
|
||||
return total_cost
|
||||
|
||||
|
||||
class AnthropicCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that tracks Anthropic token usage and costs."""
|
||||
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
successful_requests: int = 0
|
||||
total_cost: float = 0.0
|
||||
model_name: str = "claude-3-sonnet" # Default model
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
if model_name:
|
||||
self.model_name = model_name
|
||||
|
||||
# Default costs for Claude 3.7 Sonnet
|
||||
self.input_cost_per_token = 0.003 / 1000 # $3/M input tokens
|
||||
self.output_cost_per_token = 0.015 / 1000 # $15/M output tokens
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Tokens Used: {self.total_tokens}\n"
|
||||
f"\tPrompt Tokens: {self.prompt_tokens}\n"
|
||||
f"\tCompletion Tokens: {self.completion_tokens}\n"
|
||||
f"Successful Requests: {self.successful_requests}\n"
|
||||
f"Total Cost (USD): ${self.total_cost:.6f}"
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Record the model name if available."""
|
||||
if "name" in serialized:
|
||||
self.model_name = serialized["name"]
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Count tokens as they're generated."""
|
||||
with self._lock:
|
||||
self.completion_tokens += 1
|
||||
self.total_tokens += 1
|
||||
token_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, 1, is_completion=True
|
||||
)
|
||||
self.total_cost += token_cost
|
||||
|
||||
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||
"""Collect token usage from response."""
|
||||
token_usage = {}
|
||||
|
||||
# Try to extract token usage from response
|
||||
if hasattr(response, "llm_output") and response.llm_output:
|
||||
llm_output = response.llm_output
|
||||
if "token_usage" in llm_output:
|
||||
token_usage = llm_output["token_usage"]
|
||||
elif "usage" in llm_output:
|
||||
usage = llm_output["usage"]
|
||||
|
||||
# Handle Anthropic's specific usage format
|
||||
if "input_tokens" in usage:
|
||||
token_usage["prompt_tokens"] = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
token_usage["completion_tokens"] = usage["output_tokens"]
|
||||
|
||||
# Extract model name if available
|
||||
if "model_name" in llm_output:
|
||||
self.model_name = llm_output["model_name"]
|
||||
|
||||
# Try to get usage from response.usage
|
||||
elif hasattr(response, "usage"):
|
||||
usage = response.usage
|
||||
if hasattr(usage, "prompt_tokens"):
|
||||
token_usage["prompt_tokens"] = usage.prompt_tokens
|
||||
if hasattr(usage, "completion_tokens"):
|
||||
token_usage["completion_tokens"] = usage.completion_tokens
|
||||
if hasattr(usage, "total_tokens"):
|
||||
token_usage["total_tokens"] = usage.total_tokens
|
||||
|
||||
# Extract usage from generations if available
|
||||
elif hasattr(response, "generations") and response.generations:
|
||||
for gen in response.generations:
|
||||
if gen and hasattr(gen[0], "generation_info"):
|
||||
gen_info = gen[0].generation_info or {}
|
||||
if "usage" in gen_info:
|
||||
token_usage = gen_info["usage"]
|
||||
break
|
||||
|
||||
# Update counts with lock to prevent race conditions
|
||||
with self._lock:
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
|
||||
# Only update prompt tokens if we have them
|
||||
if prompt_tokens > 0:
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.total_tokens += prompt_tokens
|
||||
prompt_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, prompt_tokens, is_completion=False
|
||||
)
|
||||
self.total_cost += prompt_cost
|
||||
|
||||
# Only update completion tokens if not already counted by on_llm_new_token
|
||||
if completion_tokens > 0 and completion_tokens > self.completion_tokens:
|
||||
additional_tokens = completion_tokens - self.completion_tokens
|
||||
self.completion_tokens = completion_tokens
|
||||
self.total_tokens += additional_tokens
|
||||
completion_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, additional_tokens, is_completion=True
|
||||
)
|
||||
self.total_cost += completion_cost
|
||||
|
||||
self.successful_requests += 1
|
||||
|
||||
def __copy__(self) -> "AnthropicCallbackHandler":
|
||||
"""Return a copy of the callback handler."""
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
|
||||
"""Return a deep copy of the callback handler."""
|
||||
return self
|
||||
|
||||
|
||||
# Create a context variable for our custom callback
|
||||
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
|
||||
"anthropic_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_anthropic_callback(
|
||||
model_name: Optional[str] = None,
|
||||
) -> AnthropicCallbackHandler:
|
||||
"""Get the Anthropic callback handler in a context manager.
|
||||
which conveniently exposes token and cost information.
|
||||
|
||||
Args:
|
||||
model_name: Optional model name to use for cost calculation.
|
||||
|
||||
Returns:
|
||||
AnthropicCallbackHandler: The Anthropic callback handler.
|
||||
|
||||
Example:
|
||||
>>> with get_anthropic_callback("claude-3-sonnet") as cb:
|
||||
... # Use the callback handler
|
||||
... # cb.total_tokens, cb.total_cost will be available after
|
||||
"""
|
||||
cb = AnthropicCallbackHandler(model_name)
|
||||
anthropic_callback_var.set(cb)
|
||||
yield cb
|
||||
anthropic_callback_var.set(None)
|
||||
Loading…
Reference in New Issue