feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models (#118)

style(agent_utils.py): format imports and code for better readability
refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability
chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
Ariel Frischer, 2025-03-10 04:08:12 -07:00, committed by GitHub
parent d194868cff
commit 2899b5f848
3 changed files with 1040 additions and 623 deletions
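
A minimal usage sketch of the handler introduced here, via the context manager exported by the new anthropic_callback_handler.py (the body of the with-block is a placeholder; any LangChain call that emits Anthropic LLM events will be counted):

    from ra_aid.callbacks.anthropic_callback_handler import get_anthropic_callback

    with get_anthropic_callback("claude-3-7-sonnet-20250219") as cb:
        ...  # run LangChain/LangGraph calls here

    # After the block, cb holds accumulated counts:
    print(cb)  # Tokens Used, Prompt/Completion Tokens, Successful Requests, Total Cost (USD)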

agent_utils.py

@@ -10,6 +10,9 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence
+from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from openai import RateLimitError as OpenAIRateLimitError
@@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
 from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
 from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
 from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
-from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
+from ra_aid.prompts.reasoning_assist_prompt import (
+    REASONING_ASSIST_PROMPT_PLANNING,
+    REASONING_ASSIST_PROMPT_IMPLEMENTATION,
+    REASONING_ASSIST_PROMPT_RESEARCH,
+)
 from ra_aid.prompts.research_prompts import (
     RESEARCH_ONLY_PROMPT,
     RESEARCH_PROMPT,
@@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
 )
 from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
-from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
-from ra_aid.database.repositories.human_input_repository import get_human_input_repository
-from ra_aid.database.repositories.research_note_repository import get_research_note_repository
+from ra_aid.database.repositories.key_snippet_repository import (
+    get_key_snippet_repository,
+)
+from ra_aid.database.repositories.human_input_repository import (
+    get_human_input_repository,
+)
+from ra_aid.database.repositories.research_note_repository import (
+    get_research_note_repository,
+)
 from ra_aid.database.repositories.work_log_repository import get_work_log_repository
 from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@@ -332,7 +345,9 @@ def create_agent(
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
             agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-            return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+            return create_react_agent(
+                model, tools, interrupt_after=["tools"], **agent_kwargs
+            )
         else:
             logger.debug("Using CiaynAgent agent instance")
             return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
@@ -343,7 +358,9 @@ def create_agent(
         config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-        return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+        return create_react_agent(
+            model, tools, interrupt_after=["tools"], **agent_kwargs
+        )


 def run_research_agent(
@@ -406,7 +423,9 @@ def run_research_agent(
             recent_inputs = human_input_repository.get_recent(1)
             if recent_inputs and len(recent_inputs) > 0:
                 last_human_input = recent_inputs[0].content
-                base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
+                base_task = (
+                    f"<last human input>{last_human_input}</last human input>\n{base_task}"
+                )
     except RuntimeError as e:
         logger.error(f"Failed to access human input repository: {str(e)}")
         # Continue without appending last human input
@@ -416,7 +435,9 @@ def run_research_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
-    key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+    key_snippets = format_key_snippets_dict(
+        get_key_snippet_repository().get_snippets_dict()
+    )
     related_files = get_related_files()

     try:
@@ -445,7 +466,9 @@ def run_research_agent(
     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )
     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -459,7 +482,9 @@ def run_research_agent(
     # Get research note information for reasoning assistance
     try:
-        research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
+        research_notes = format_research_notes_dict(
+            get_research_note_repository().get_notes_dict()
+        )
     except Exception as e:
         logger.warning(f"Failed to get research notes: {e}")
         research_notes = ""
@@ -467,7 +492,10 @@ def run_research_agent(
     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
@@ -503,7 +531,13 @@ def run_research_agent(
             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best research approach."
+                    ),
+                    title="📝 Thinking about research strategy...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
@@ -517,7 +551,7 @@ def run_research_agent(
             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
@@ -533,19 +567,27 @@ def run_research_agent(
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
                     console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
                     )

                 # Use response_text if available, otherwise fall back to joining
@@ -553,9 +595,11 @@ def run_research_agent(
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
@@ -563,16 +607,22 @@ def run_research_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Research Strategy Guidance",
+                    border_style="blue",
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            )

             logger.info("Received expert guidance for research")
         except Exception as e:
@@ -643,9 +693,7 @@ def run_research_agent(
         if agent is not None:
             logger.debug("Research agent created successfully")
             none_or_fallback_handler = init_fallback_handler(agent, tools)
-            _result = run_agent_with_retry(
-                agent, prompt, none_or_fallback_handler
-            )
+            _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
             if _result:
                 # Log research completion
                 log_work_event(f"Completed research phase for: {base_task_or_query}")
@@ -731,7 +779,9 @@ def run_web_research_agent(
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""
@@ -771,9 +821,7 @@ def run_web_research_agent(
         logger.debug("Web research agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log web research completion
             log_work_event(f"Completed web research phase for: {query}")
@@ -844,7 +892,9 @@ def run_planning_agent(
     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )

     if force_assistance:
         reasoning_assist_enabled = True
@@ -869,7 +919,9 @@ def run_planning_agent(
     # Make sure key_snippets is defined before using it
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""
@@ -898,7 +950,10 @@ def run_planning_agent(
     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
@@ -934,7 +989,13 @@ def run_planning_agent(
             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best way to do this."
+                    ),
+                    title="📝 Thinking about the plan...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
@@ -948,7 +1009,7 @@ def run_planning_agent(
             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
@@ -964,19 +1025,27 @@ def run_planning_agent(
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
                     console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
                    )

                 # Use response_text if available, otherwise fall back to joining
@@ -984,9 +1053,11 @@ def run_planning_agent(
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
@@ -994,16 +1065,20 @@ def run_planning_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
+                Panel(
+                    Markdown(content), title="Reasoning Guidance", border_style="blue"
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            )

             logger.info("Received expert guidance for planning")
         except Exception as e:
@@ -1050,7 +1125,9 @@ def run_planning_agent(
     )

     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1060,9 +1137,7 @@ def run_planning_agent(
     try:
         logger.debug("Planning agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, planning_prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
         if _result:
             # Log planning completion
             log_work_event(f"Completed planning phase for: {base_task}")
@@ -1168,7 +1243,9 @@ def run_task_implementation_agent(
     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )

     if force_assistance:
         reasoning_assist_enabled = True
@@ -1186,7 +1263,10 @@ def run_task_implementation_agent(
     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting implementation guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
@@ -1197,7 +1277,9 @@ def run_task_implementation_agent(
                     tool_info = get_tool_info(tool.func)
                     name = tool.func.__name__
                     description = inspect.getdoc(tool.func)
-                    tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
+                    tool_metadata.append(
+                        f"Tool: {name}\nDescription: {description}\n"
+                    )
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")
@@ -1213,7 +1295,9 @@ def run_task_implementation_agent(
                 working_directory=working_directory,
                 task=task,
                 key_facts=key_facts,
-                key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+                key_snippets=format_key_snippets_dict(
+                    get_key_snippet_repository().get_snippets_dict()
+                ),
                 research_notes=formatted_research_notes,
                 related_files="\n".join(related_files),
                 env_inv=env_inv,
@@ -1222,7 +1306,13 @@ def run_task_implementation_agent(
             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best implementation approach."
+                    ),
+                    title="📝 Thinking about implementation...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for implementation reasoning assist")
@@ -1236,7 +1326,7 @@ def run_task_implementation_agent(
             # Process response content
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
@@ -1249,12 +1339,16 @@ def run_task_implementation_agent(
                 supports_thinking=supports_thinking,
                 panel_title="💭 Implementation Thinking",
                 panel_style="yellow",
-                logger=logger
+                logger=logger,
             )

             # Display the implementation guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Implementation Guidance",
+                    border_style="blue",
+                )
             )

             # Format the implementation guidance section for the prompt
@@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
         plan=plan,
         related_files=related_files,
         key_facts=key_facts,
-        key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+        key_snippets=format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        ),
         research_notes=formatted_research_notes,
         work_log=get_work_log_repository().format_work_log(),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
@@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
     )

     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
     try:
         logger.debug("Implementation agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log task implementation completion
             log_work_event(f"Completed implementation of task: {task}")
@@ -1380,19 +1476,29 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     # 1. Check if this is a ValueError with 429 code or rate limit phrases
     if isinstance(e, ValueError):
         error_str = str(e).lower()
-        rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
-        if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
+        rate_limit_phrases = [
+            "429",
+            "rate limit",
+            "too many requests",
+            "quota exceeded",
+        ]
+        if "code" not in error_str and not any(
+            phrase in error_str for phrase in rate_limit_phrases
+        ):
             raise e

     # 2. Check for status_code or http_status attribute equal to 429
-    if hasattr(e, 'status_code') and e.status_code == 429:
+    if hasattr(e, "status_code") and e.status_code == 429:
         pass  # This is a rate limit error, continue with retry logic
-    elif hasattr(e, 'http_status') and e.http_status == 429:
+    elif hasattr(e, "http_status") and e.http_status == 429:
         pass  # This is a rate limit error, continue with retry logic
     # 3. Check for rate limit phrases in error message
     elif isinstance(e, Exception) and not isinstance(e, ValueError):
         error_str = str(e).lower()
-        if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
+        if not any(
+            phrase in error_str
+            for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
+        ) and not ("rate" in error_str and "limit" in error_str):
            # This doesn't look like a rate limit error, but we'll still retry other API errors
            pass
@@ -1468,22 +1574,39 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
     human-in-the-loop interruptions using interrupt_after=["tools"].
     """
     config = get_config_repository().get_all()
+    stream_config = config.copy()
+
+    cb = None
+    if is_anthropic_claude(config):
+        model_name = config.get("model", "")
+        full_model_name = model_name
+        cb = AnthropicCallbackHandler(full_model_name)
+
+        if "callbacks" not in stream_config:
+            stream_config["callbacks"] = []
+        stream_config["callbacks"].append(cb)
+
     while True:
-        # Process each chunk from the agent stream.
-        for chunk in agent.stream({"messages": msg_list}, config):
+        for chunk in agent.stream({"messages": msg_list}, stream_config):
             logger.debug("Agent output: %s", chunk)
             check_interrupt()
             agent_type = get_agent_type(agent)
             print_agent_output(chunk, agent_type)
             if is_completed() or should_exit():
                 reset_completion_flags()
-                return True  # Exit immediately when finished or signaled to exit.
+                if cb:
+                    logger.debug(f"AnthropicCallbackHandler:\n{cb}")
+                return True
         logger.debug("Stream iteration ended; checking agent state for continuation.")

         # Prepare state configuration, ensuring 'configurable' is present.
         state_config = get_config_repository().get_all().copy()
         if "configurable" not in state_config:
-            logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
+            logger.debug(
+                "Key 'configurable' not found in config; adding it as an empty dict."
+            )
             state_config["configurable"] = {}

         logger.debug("Using state_config for agent.get_state(): %s", state_config)
@@ -1491,21 +1614,27 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
             state = agent.get_state(state_config)
             logger.debug("Agent state retrieved: %s", state)
         except Exception as e:
-            logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
+            logger.error(
+                "Error retrieving agent state with state_config %s: %s", state_config, e
+            )
             raise

-        # If the state indicates that further steps remain (i.e. state.next is non-empty),
-        # then resume execution by invoking the agent with no new input.
         if state.next:
-            logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
-            agent.invoke(None, config)
+            logger.debug(
+                "State indicates continuation (state.next: %s); resuming execution.",
+                state.next,
+            )
+            agent.invoke(None, stream_config)
             continue
         else:
             logger.debug("No continuation indicated in state; exiting stream loop.")
             break

+    if cb:
+        logger.debug(f"AnthropicCallbackHandler:\n{cb}")
+
     return True


 def run_agent_with_retry(
     agent: RAgents,
     prompt: str,
@@ -1517,7 +1646,9 @@ def run_agent_with_retry(
     max_retries = 20
     base_delay = 1
     test_attempts = 0
-    _max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
+    _max_test_retries = get_config_repository().get(
+        "max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
+    )
     auto_test = get_config_repository().get("auto_test", False)
     original_prompt = prompt
     msg_list = [HumanMessage(content=prompt)]
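
The streaming change above reduces to attaching the handler through LangChain's standard callbacks entry in the config passed to agent.stream. A condensed sketch of the same wiring, with agent and msg_list as placeholders (setdefault is shorthand for the explicit key check in the diff):

    config = get_config_repository().get_all()
    stream_config = config.copy()

    cb = AnthropicCallbackHandler(config.get("model", ""))
    stream_config.setdefault("callbacks", []).append(cb)

    # Each completed LLM call reports usage to cb.on_llm_end, so cb
    # accumulates prompt/completion tokens and cost across the stream.
    for chunk in agent.stream({"messages": msg_list}, stream_config):
        print_agent_output(chunk, get_agent_type(agent))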

ra_aid/callbacks/anthropic_callback_handler.py

@@ -0,0 +1,270 @@
"""Custom callback handlers for tracking token usage and costs."""

import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import BaseCallbackHandler

# Define cost per 1K tokens for various models
ANTHROPIC_MODEL_COSTS = {
    # Claude 3.7 Sonnet input
    "claude-3-7-sonnet-20250219": 0.003,
    "anthropic/claude-3.7-sonnet": 0.003,
    "claude-3.7-sonnet": 0.003,
    # Claude 3.7 Sonnet output
    "claude-3-7-sonnet-20250219-completion": 0.015,
    "anthropic/claude-3.7-sonnet-completion": 0.015,
    "claude-3.7-sonnet-completion": 0.015,
    # Claude 3 Opus input
    "claude-3-opus-20240229": 0.015,
    "anthropic/claude-3-opus": 0.015,
    "claude-3-opus": 0.015,
    # Claude 3 Opus output
    "claude-3-opus-20240229-completion": 0.075,
    "anthropic/claude-3-opus-completion": 0.075,
    "claude-3-opus-completion": 0.075,
    # Claude 3 Sonnet input
    "claude-3-sonnet-20240229": 0.003,
    "anthropic/claude-3-sonnet": 0.003,
    "claude-3-sonnet": 0.003,
    # Claude 3 Sonnet output
    "claude-3-sonnet-20240229-completion": 0.015,
    "anthropic/claude-3-sonnet-completion": 0.015,
    "claude-3-sonnet-completion": 0.015,
    # Claude 3 Haiku input
    "claude-3-haiku-20240307": 0.00025,
    "anthropic/claude-3-haiku": 0.00025,
    "claude-3-haiku": 0.00025,
    # Claude 3 Haiku output
    "claude-3-haiku-20240307-completion": 0.00125,
    "anthropic/claude-3-haiku-completion": 0.00125,
    "claude-3-haiku-completion": 0.00125,
    # Claude 2 input
    "claude-2": 0.008,
    "claude-2.0": 0.008,
    "claude-2.1": 0.008,
    # Claude 2 output
    "claude-2-completion": 0.024,
    "claude-2.0-completion": 0.024,
    "claude-2.1-completion": 0.024,
    # Claude Instant input
    "claude-instant-1": 0.0016,
    "claude-instant-1.2": 0.0016,
    # Claude Instant output
    "claude-instant-1-completion": 0.0055,
    "claude-instant-1.2-completion": 0.0055,
}
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
    """
    Standardize the model name to a format that can be used for cost calculation.

    Args:
        model_name: Model name to standardize.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Standardized model name.
    """
    if not model_name:
        model_name = "claude-3-sonnet"

    model_name = model_name.lower()

    # Handle OpenRouter prefixes
    if model_name.startswith("anthropic/"):
        model_name = model_name[len("anthropic/") :]

    # Add completion suffix if needed
    if is_completion and not model_name.endswith("-completion"):
        model_name = model_name + "-completion"

    return model_name


def get_anthropic_token_cost_for_model(
    model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Cost in USD.
    """
    model_name = standardize_model_name(model_name, is_completion)

    if model_name not in ANTHROPIC_MODEL_COSTS:
        # Default to Claude 3 Sonnet pricing if model not found
        model_name = (
            "claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
        )

    cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
    total_cost = cost_per_1k * (num_tokens / 1000)

    return total_cost
class AnthropicCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks Anthropic token usage and costs."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0
    model_name: str = "claude-3-sonnet"  # Default model

    def __init__(self, model_name: Optional[str] = None) -> None:
        super().__init__()
        self._lock = threading.Lock()
        if model_name:
            self.model_name = model_name

        # Default costs for Claude 3.7 Sonnet
        self.input_cost_per_token = 0.003 / 1000  # $3/M input tokens
        self.output_cost_per_token = 0.015 / 1000  # $15/M output tokens

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost:.6f}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record the model name if available."""
        if "name" in serialized:
            self.model_name = serialized["name"]

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Count tokens as they're generated."""
        with self._lock:
            self.completion_tokens += 1
            self.total_tokens += 1
            token_cost = get_anthropic_token_cost_for_model(
                self.model_name, 1, is_completion=True
            )
            self.total_cost += token_cost
    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Collect token usage from response."""
        token_usage = {}

        # Try to extract token usage from response
        if hasattr(response, "llm_output") and response.llm_output:
            llm_output = response.llm_output
            if "token_usage" in llm_output:
                token_usage = llm_output["token_usage"]
            elif "usage" in llm_output:
                usage = llm_output["usage"]
                # Handle Anthropic's specific usage format
                if "input_tokens" in usage:
                    token_usage["prompt_tokens"] = usage["input_tokens"]
                if "output_tokens" in usage:
                    token_usage["completion_tokens"] = usage["output_tokens"]

            # Extract model name if available
            if "model_name" in llm_output:
                self.model_name = llm_output["model_name"]

        # Try to get usage from response.usage
        elif hasattr(response, "usage"):
            usage = response.usage
            if hasattr(usage, "prompt_tokens"):
                token_usage["prompt_tokens"] = usage.prompt_tokens
            if hasattr(usage, "completion_tokens"):
                token_usage["completion_tokens"] = usage.completion_tokens
            if hasattr(usage, "total_tokens"):
                token_usage["total_tokens"] = usage.total_tokens

        # Extract usage from generations if available
        elif hasattr(response, "generations") and response.generations:
            for gen in response.generations:
                if gen and hasattr(gen[0], "generation_info"):
                    gen_info = gen[0].generation_info or {}
                    if "usage" in gen_info:
                        token_usage = gen_info["usage"]
                        break

        # Update counts with lock to prevent race conditions
        with self._lock:
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)

            # Only update prompt tokens if we have them
            if prompt_tokens > 0:
                self.prompt_tokens += prompt_tokens
                self.total_tokens += prompt_tokens
                prompt_cost = get_anthropic_token_cost_for_model(
                    self.model_name, prompt_tokens, is_completion=False
                )
                self.total_cost += prompt_cost

            # Only update completion tokens if not already counted by on_llm_new_token
            if completion_tokens > 0 and completion_tokens > self.completion_tokens:
                additional_tokens = completion_tokens - self.completion_tokens
                self.completion_tokens = completion_tokens
                self.total_tokens += additional_tokens
                completion_cost = get_anthropic_token_cost_for_model(
                    self.model_name, additional_tokens, is_completion=True
                )
                self.total_cost += completion_cost

            self.successful_requests += 1

    def __copy__(self) -> "AnthropicCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self


# Create a context variable for our custom callback
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
    "anthropic_callback", default=None
)
@contextmanager
def get_anthropic_callback(
    model_name: Optional[str] = None,
) -> AnthropicCallbackHandler:
    """Get an Anthropic callback handler in a context manager, which
    conveniently exposes token and cost information.

    Args:
        model_name: Optional model name to use for cost calculation.

    Returns:
        AnthropicCallbackHandler: The Anthropic callback handler.

    Example:
        >>> with get_anthropic_callback("claude-3-sonnet") as cb:
        ...     # Use the callback handler
        ...     # cb.total_tokens, cb.total_cost will be available after
    """
    cb = AnthropicCallbackHandler(model_name)
    anthropic_callback_var.set(cb)
    yield cb
    anthropic_callback_var.set(None)
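
A quick worked example of the cost helpers above: standardize_model_name lower-cases the name, strips an "anthropic/" prefix, and appends "-completion" for output-side pricing, and the table is priced per 1K tokens, so for one million tokens:

    standardize_model_name("anthropic/claude-3.7-sonnet")        # "claude-3.7-sonnet"
    standardize_model_name("claude-3-opus", is_completion=True)  # "claude-3-opus-completion"

    get_anthropic_token_cost_for_model("claude-3-opus", 1_000_000)                      # 0.015 * 1000 = 15.0 USD
    get_anthropic_token_cost_for_model("claude-3-opus", 1_000_000, is_completion=True)  # 0.075 * 1000 = 75.0 USD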

uv.lock (910 changed lines)

File diff suppressed because it is too large