feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models (#118)

style(agent_utils.py): format imports and code for better readability
refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability
chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
Ariel Frischer authored 2025-03-10 04:08:12 -07:00, committed by GitHub
parent d194868cff · commit 2899b5f848
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 1040 additions and 623 deletions
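Before the diff itself, a minimal sketch of the intended use of the new handler, based on the module added below (the model.invoke call site here is illustrative, not part of this commit):

    from ra_aid.callbacks.anthropic_callback_handler import get_anthropic_callback

    with get_anthropic_callback("claude-3-7-sonnet-20250219") as cb:
        # Hypothetical call site; the commit itself attaches the handler inside
        # _run_agent_stream via stream_config["callbacks"].
        # model.invoke(prompt, config={"callbacks": [cb]})
        pass

    # After the block, the handler exposes running totals:
    print(cb.total_tokens, cb.prompt_tokens, cb.completion_tokens)
    print(f"Total Cost (USD): ${cb.total_cost:.6f}")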

ra_aid/agent_utils.py

@@ -10,6 +10,9 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence

+from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
+
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from openai import RateLimitError as OpenAIRateLimitError
@@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
 from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
 from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
 from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
-from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
+from ra_aid.prompts.reasoning_assist_prompt import (
+    REASONING_ASSIST_PROMPT_PLANNING,
+    REASONING_ASSIST_PROMPT_IMPLEMENTATION,
+    REASONING_ASSIST_PROMPT_RESEARCH,
+)
 from ra_aid.prompts.research_prompts import (
     RESEARCH_ONLY_PROMPT,
     RESEARCH_PROMPT,
@@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
 )
 from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
-from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
-from ra_aid.database.repositories.human_input_repository import get_human_input_repository
-from ra_aid.database.repositories.research_note_repository import get_research_note_repository
+from ra_aid.database.repositories.key_snippet_repository import (
+    get_key_snippet_repository,
+)
+from ra_aid.database.repositories.human_input_repository import (
+    get_human_input_repository,
+)
+from ra_aid.database.repositories.research_note_repository import (
+    get_research_note_repository,
+)
 from ra_aid.database.repositories.work_log_repository import get_work_log_repository
 from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@@ -332,7 +345,9 @@ def create_agent(
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
             agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-            return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+            return create_react_agent(
+                model, tools, interrupt_after=["tools"], **agent_kwargs
+            )
         else:
             logger.debug("Using CiaynAgent agent instance")
             return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
@@ -343,7 +358,9 @@ def create_agent(
         config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-        return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+        return create_react_agent(
+            model, tools, interrupt_after=["tools"], **agent_kwargs
+        )


 def run_research_agent(
@@ -406,7 +423,9 @@ def run_research_agent(
         recent_inputs = human_input_repository.get_recent(1)
         if recent_inputs and len(recent_inputs) > 0:
             last_human_input = recent_inputs[0].content
-            base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
+            base_task = (
+                f"<last human input>{last_human_input}</last human input>\n{base_task}"
+            )
     except RuntimeError as e:
         logger.error(f"Failed to access human input repository: {str(e)}")
         # Continue without appending last human input
@@ -416,7 +435,9 @@ def run_research_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
-    key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+    key_snippets = format_key_snippets_dict(
+        get_key_snippet_repository().get_snippets_dict()
+    )
     related_files = get_related_files()

     try:
@@ -432,20 +453,22 @@ def run_research_agent(
         human_interaction=hil,
         web_research_enabled=get_config_repository().get("web_research_enabled", False),
     )

     # Get model info for reasoning assistance configuration
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")

     # Get model configuration to check for reasoning_assist_default
     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )
     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -453,26 +476,31 @@ def run_research_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)

     expert_guidance = ""

     # Get research note information for reasoning assistance
     try:
-        research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
+        research_notes = format_research_notes_dict(
+            get_research_note_repository().get_notes_dict()
+        )
     except Exception as e:
         logger.warning(f"Failed to get research notes: {e}")
         research_notes = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
@@ -481,13 +509,13 @@ def run_research_agent(
                     tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_RESEARCH.format(
                 current_date=current_date,
@@ -500,62 +528,78 @@ def run_research_agent(
                 env_inv=get_env_inv(),
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best research approach."
+                    ),
+                    title="📝 Thinking about research strategy...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process content based on its type
             if isinstance(content, list):
                 # Handle structured thinking mode (e.g., Claude 3.7)
                 thinking_content = None
                 response_text = None

                 # Process each item in the list
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
-                    console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
-                    )
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
+                    console.print(
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
+                    )

                 # Use response_text if available, otherwise fall back to joining
                 if response_text:
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
@@ -563,22 +607,28 @@ def run_research_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Research Strategy Guidance",
+                    border_style="blue",
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            )

             logger.info("Received expert guidance for research")
         except Exception as e:
             logger.error("Error getting expert guidance for research: %s", e)
             expert_guidance = ""

     agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@@ -588,7 +638,7 @@ def run_research_agent(
         if get_config_repository().get("web_research_enabled")
         else ""
     )

     # Prepare expert guidance section if expert guidance is available
     expert_guidance_section = ""
     if expert_guidance:
@@ -600,7 +650,7 @@ def run_research_agent(
     # We get research notes earlier for reasoning assistance

     # Get environment inventory information
     prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
         current_date=current_date,
         working_directory=working_directory,
@@ -643,9 +693,7 @@ def run_research_agent(
     if agent is not None:
         logger.debug("Research agent created successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log research completion
             log_work_event(f"Completed research phase for: {base_task_or_query}")
@@ -731,7 +779,9 @@ def run_web_research_agent(
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""
@@ -741,7 +791,7 @@ def run_web_research_agent(
     working_directory = os.getcwd()

     # Get environment inventory information
     prompt = WEB_RESEARCH_PROMPT.format(
         current_date=current_date,
         working_directory=working_directory,
@@ -771,9 +821,7 @@ def run_web_research_agent(
         logger.debug("Web research agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log web research completion
             log_work_event(f"Completed web research phase for: {query}")
@@ -835,17 +883,19 @@ def run_planning_agent(
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")
     logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)

     # Get model configuration to check for reasoning_assist_default
     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )
     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -853,27 +903,29 @@ def run_planning_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)

     # Get all the context information (used both for normal planning and reasoning assist)
     current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     working_directory = os.getcwd()

     # Make sure key_facts is defined before using it
     try:
         key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""

     # Make sure key_snippets is defined before using it
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""

     # Get formatted research notes using repository
     try:
         repository = get_research_note_repository()
@@ -882,28 +934,31 @@ def run_planning_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access research note repository: {str(e)}")
         formatted_research_notes = ""

     # Get related files
     related_files = "\n".join(get_related_files())

     # Get environment inventory information
     env_inv = get_env_inv()

     # Display the planning stage header before any reasoning assistance
     print_stage_header("Planning Stage")

     # Initialize expert guidance section
     expert_guidance = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
@@ -912,13 +967,13 @@ def run_planning_agent(
                     tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_PLANNING.format(
                 current_date=current_date,
@@ -931,62 +986,78 @@ def run_planning_agent(
                 env_inv=env_inv,
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best way to do this."
+                    ),
+                    title="📝 Thinking about the plan...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process content based on its type
             if isinstance(content, list):
                 # Handle structured thinking mode (e.g., Claude 3.7)
                 thinking_content = None
                 response_text = None

                 # Process each item in the list
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
-                    console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
-                    )
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
+                    console.print(
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
+                    )

                 # Use response_text if available, otherwise fall back to joining
                 if response_text:
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
@@ -994,24 +1065,28 @@ def run_planning_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
+                Panel(
+                    Markdown(content), title="Reasoning Guidance", border_style="blue"
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            )

             logger.info("Received expert guidance for planning")
         except Exception as e:
             logger.error("Error getting expert guidance for planning: %s", e)
             expert_guidance = ""

     agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
     web_research_section = (
@@ -1019,7 +1094,7 @@ def run_planning_agent(
         if get_config_repository().get("web_research_enabled", False)
         else ""
     )

     # Prepare expert guidance section if expert guidance is available
     expert_guidance_section = ""
     if expert_guidance:
@@ -1050,7 +1125,9 @@ def run_planning_agent(
     )
     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1060,9 +1137,7 @@ def run_planning_agent(
     try:
         logger.debug("Planning agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, planning_prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
         if _result:
             # Log planning completion
             log_work_event(f"Completed planning phase for: {base_task}")
@@ -1135,7 +1210,7 @@ def run_task_implementation_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""

     # Get formatted research notes using repository
     try:
         repository = get_research_note_repository()
@@ -1144,7 +1219,7 @@ def run_task_implementation_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access research note repository: {str(e)}")
         formatted_research_notes = ""

     # Get latest project info
     try:
         project_info = get_project_info(".")
@@ -1152,24 +1227,26 @@ def run_task_implementation_agent(
     except Exception as e:
         logger.warning("Failed to get project info: %s", str(e))
         formatted_project_info = "Project info unavailable"

     # Get environment inventory information
     env_inv = get_env_inv()

     # Get model configuration to check for reasoning_assist_default
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")
     logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)

     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )
     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -1177,71 +1254,84 @@ def run_task_implementation_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)

     # Initialize implementation guidance section
     implementation_guidance_section = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting implementation guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
                     name = tool.func.__name__
                     description = inspect.getdoc(tool.func)
-                    tool_metadata.append(f"Tool: {name}\\nDescription: {description}\\n")
+                    tool_metadata.append(
+                        f"Tool: {name}\\nDescription: {description}\\n"
+                    )
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt for implementation
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_IMPLEMENTATION.format(
                 current_date=current_date,
                 working_directory=working_directory,
                 task=task,
                 key_facts=key_facts,
-                key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+                key_snippets=format_key_snippets_dict(
+                    get_key_snippet_repository().get_snippets_dict()
+                ),
                 research_notes=formatted_research_notes,
                 related_files="\\n".join(related_files),
                 env_inv=env_inv,
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best implementation approach."
+                    ),
+                    title="📝 Thinking about implementation...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for implementation reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Process response content
             content = None
-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process the response content using the centralized function
             content, extracted_thinking = process_thinking_content(
                 content=content,
@@ -1249,24 +1339,28 @@ def run_task_implementation_agent(
                 supports_thinking=supports_thinking,
                 panel_title="💭 Implementation Thinking",
                 panel_style="yellow",
-                logger=logger
+                logger=logger,
             )

             # Display the implementation guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Implementation Guidance",
+                    border_style="blue",
+                )
             )

             # Format the implementation guidance section for the prompt
             implementation_guidance_section = f"""<implementation guidance>
 {content}
 </implementation guidance>"""
             logger.info("Received implementation guidance")
         except Exception as e:
             logger.error("Error getting implementation guidance: %s", e)
             implementation_guidance_section = ""

     prompt = IMPLEMENTATION_PROMPT.format(
         current_date=current_date,
         working_directory=working_directory,
@@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
         plan=plan,
         related_files=related_files,
         key_facts=key_facts,
-        key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+        key_snippets=format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        ),
         research_notes=formatted_research_notes,
         work_log=get_work_log_repository().format_work_log(),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
@@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
     )
     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
     try:
         logger.debug("Implementation agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log task implementation completion
             log_work_event(f"Completed implementation of task: {task}")
@@ -1380,27 +1476,37 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     # 1. Check if this is a ValueError with 429 code or rate limit phrases
     if isinstance(e, ValueError):
         error_str = str(e).lower()
-        rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
-        if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
+        rate_limit_phrases = [
+            "429",
+            "rate limit",
+            "too many requests",
+            "quota exceeded",
+        ]
+        if "code" not in error_str and not any(
+            phrase in error_str for phrase in rate_limit_phrases
+        ):
             raise e

     # 2. Check for status_code or http_status attribute equal to 429
-    if hasattr(e, 'status_code') and e.status_code == 429:
+    if hasattr(e, "status_code") and e.status_code == 429:
         pass  # This is a rate limit error, continue with retry logic
-    elif hasattr(e, 'http_status') and e.http_status == 429:
+    elif hasattr(e, "http_status") and e.http_status == 429:
         pass  # This is a rate limit error, continue with retry logic
     # 3. Check for rate limit phrases in error message
     elif isinstance(e, Exception) and not isinstance(e, ValueError):
         error_str = str(e).lower()
-        if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
+        if not any(
+            phrase in error_str
+            for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
+        ) and not ("rate" in error_str and "limit" in error_str):
             # This doesn't look like a rate limit error, but we'll still retry other API errors
             pass

     # Apply common retry logic for all identified errors
     if attempt == max_retries - 1:
         logger.error("Max retries reached, failing: %s", str(e))
         raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
     logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
     delay = base_delay * (2**attempt)
     print_error(
@@ -1457,55 +1563,78 @@ def _handle_fallback_response(
 def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
     """
     Streams agent output while handling completion and interruption.

     For each chunk, it logs the output, calls check_interrupt(), prints agent output,
     and then checks if is_completed() or should_exit() are true. If so, it resets completion
     flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks),
     the function retrieves the agent's state. If the state indicates further steps (i.e. state.next is non-empty),
     it resumes execution via agent.invoke(None, config); otherwise, it exits the loop.

     This function adheres to the latest LangGraph best practices (as of March 2025) for handling
     human-in-the-loop interruptions using interrupt_after=["tools"].
     """
     config = get_config_repository().get_all()
+    stream_config = config.copy()
+    cb = None
+    if is_anthropic_claude(config):
+        model_name = config.get("model", "")
+        full_model_name = model_name
+        cb = AnthropicCallbackHandler(full_model_name)
+
+        if "callbacks" not in stream_config:
+            stream_config["callbacks"] = []
+        stream_config["callbacks"].append(cb)
+
     while True:
-        # Process each chunk from the agent stream.
-        for chunk in agent.stream({"messages": msg_list}, config):
+        for chunk in agent.stream({"messages": msg_list}, stream_config):
             logger.debug("Agent output: %s", chunk)
             check_interrupt()
             agent_type = get_agent_type(agent)
             print_agent_output(chunk, agent_type)

             if is_completed() or should_exit():
                 reset_completion_flags()
-                return True  # Exit immediately when finished or signaled to exit.
+                if cb:
+                    logger.debug(f"AnthropicCallbackHandler:\n{cb}")
+                return True

         logger.debug("Stream iteration ended; checking agent state for continuation.")

         # Prepare state configuration, ensuring 'configurable' is present.
         state_config = get_config_repository().get_all().copy()
         if "configurable" not in state_config:
-            logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
+            logger.debug(
+                "Key 'configurable' not found in config; adding it as an empty dict."
+            )
             state_config["configurable"] = {}

         logger.debug("Using state_config for agent.get_state(): %s", state_config)
         try:
             state = agent.get_state(state_config)
             logger.debug("Agent state retrieved: %s", state)
         except Exception as e:
-            logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
+            logger.error(
+                "Error retrieving agent state with state_config %s: %s", state_config, e
+            )
             raise

-        # If the state indicates that further steps remain (i.e. state.next is non-empty),
-        # then resume execution by invoking the agent with no new input.
         if state.next:
-            logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
-            agent.invoke(None, config)
+            logger.debug(
+                "State indicates continuation (state.next: %s); resuming execution.",
+                state.next,
+            )
+            agent.invoke(None, stream_config)
             continue
         else:
             logger.debug("No continuation indicated in state; exiting stream loop.")
             break

+    if cb:
+        logger.debug(f"AnthropicCallbackHandler:\n{cb}")
+
     return True


 def run_agent_with_retry(
     agent: RAgents,
     prompt: str,
@@ -1517,7 +1646,9 @@ def run_agent_with_retry(
     max_retries = 20
     base_delay = 1
     test_attempts = 0
-    _max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
+    _max_test_retries = get_config_repository().get(
+        "max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
+    )
     auto_test = get_config_repository().get("auto_test", False)
     original_prompt = prompt
     msg_list = [HumanMessage(content=prompt)]
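The streaming change above boils down to three moves: copy the run config, attach an AnthropicCallbackHandler when the provider is Anthropic Claude, and pass the same stream_config to both agent.stream and the resuming agent.invoke so counts keep accumulating across interrupts. A distilled sketch of that wiring (attach_anthropic_callback is a hypothetical helper for illustration, not part of the commit):

    def attach_anthropic_callback(config: dict):
        # Mirrors the logic added to _run_agent_stream: shallow-copy the config
        # and register the handler so it sees every streamed chunk.
        stream_config = config.copy()
        cb = None
        if is_anthropic_claude(config):
            cb = AnthropicCallbackHandler(config.get("model", ""))
            if "callbacks" not in stream_config:
                stream_config["callbacks"] = []
            stream_config["callbacks"].append(cb)
        return stream_config, cb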

ra_aid/callbacks/anthropic_callback_handler.py (new file)

@@ -0,0 +1,270 @@
"""Custom callback handlers for tracking token usage and costs."""
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler
# Define cost per 1K tokens for various models
ANTHROPIC_MODEL_COSTS = {
# Claude 3.7 Sonnet input
"claude-3-7-sonnet-20250219": 0.003,
"anthropic/claude-3.7-sonnet": 0.003,
"claude-3.7-sonnet": 0.003,
# Claude 3.7 Sonnet output
"claude-3-7-sonnet-20250219-completion": 0.015,
"anthropic/claude-3.7-sonnet-completion": 0.015,
"claude-3.7-sonnet-completion": 0.015,
# Claude 3 Opus input
"claude-3-opus-20240229": 0.015,
"anthropic/claude-3-opus": 0.015,
"claude-3-opus": 0.015,
# Claude 3 Opus output
"claude-3-opus-20240229-completion": 0.075,
"anthropic/claude-3-opus-completion": 0.075,
"claude-3-opus-completion": 0.075,
# Claude 3 Sonnet input
"claude-3-sonnet-20240229": 0.003,
"anthropic/claude-3-sonnet": 0.003,
"claude-3-sonnet": 0.003,
# Claude 3 Sonnet output
"claude-3-sonnet-20240229-completion": 0.015,
"anthropic/claude-3-sonnet-completion": 0.015,
"claude-3-sonnet-completion": 0.015,
# Claude 3 Haiku input
"claude-3-haiku-20240307": 0.00025,
"anthropic/claude-3-haiku": 0.00025,
"claude-3-haiku": 0.00025,
# Claude 3 Haiku output
"claude-3-haiku-20240307-completion": 0.00125,
"anthropic/claude-3-haiku-completion": 0.00125,
"claude-3-haiku-completion": 0.00125,
# Claude 2 input
"claude-2": 0.008,
"claude-2.0": 0.008,
"claude-2.1": 0.008,
# Claude 2 output
"claude-2-completion": 0.024,
"claude-2.0-completion": 0.024,
"claude-2.1-completion": 0.024,
# Claude Instant input
"claude-instant-1": 0.0016,
"claude-instant-1.2": 0.0016,
# Claude Instant output
"claude-instant-1-completion": 0.0055,
"claude-instant-1.2-completion": 0.0055,
}
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
    """
    Standardize the model name to a format that can be used for cost calculation.

    Args:
        model_name: Model name to standardize.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Standardized model name.
    """
    if not model_name:
        model_name = "claude-3-sonnet"

    model_name = model_name.lower()

    # Handle OpenRouter prefixes
    if model_name.startswith("anthropic/"):
        model_name = model_name[len("anthropic/") :]

    # Add completion suffix if needed
    if is_completion and not model_name.endswith("-completion"):
        model_name = model_name + "-completion"

    return model_name


def get_anthropic_token_cost_for_model(
    model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Cost in USD.
    """
    model_name = standardize_model_name(model_name, is_completion)
    if model_name not in ANTHROPIC_MODEL_COSTS:
        # Default to Claude 3 Sonnet pricing if model not found
        model_name = (
            "claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
        )

    cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
    total_cost = cost_per_1k * (num_tokens / 1000)

    return total_cost
class AnthropicCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks Anthropic token usage and costs."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0
    model_name: str = "claude-3-sonnet"  # Default model

    def __init__(self, model_name: Optional[str] = None) -> None:
        super().__init__()
        self._lock = threading.Lock()
        if model_name:
            self.model_name = model_name

        # Default costs for Claude 3.7 Sonnet
        self.input_cost_per_token = 0.003 / 1000  # $3/M input tokens
        self.output_cost_per_token = 0.015 / 1000  # $15/M output tokens

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost:.6f}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record the model name if available."""
        if "name" in serialized:
            self.model_name = serialized["name"]

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Count tokens as they're generated."""
        with self._lock:
            self.completion_tokens += 1
            self.total_tokens += 1
            token_cost = get_anthropic_token_cost_for_model(
                self.model_name, 1, is_completion=True
            )
            self.total_cost += token_cost

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Collect token usage from response."""
        token_usage = {}

        # Try to extract token usage from response
        if hasattr(response, "llm_output") and response.llm_output:
            llm_output = response.llm_output
            if "token_usage" in llm_output:
                token_usage = llm_output["token_usage"]
            elif "usage" in llm_output:
                usage = llm_output["usage"]
                # Handle Anthropic's specific usage format
                if "input_tokens" in usage:
                    token_usage["prompt_tokens"] = usage["input_tokens"]
                if "output_tokens" in usage:
                    token_usage["completion_tokens"] = usage["output_tokens"]

            # Extract model name if available
            if "model_name" in llm_output:
                self.model_name = llm_output["model_name"]

        # Try to get usage from response.usage
        elif hasattr(response, "usage"):
            usage = response.usage
            if hasattr(usage, "prompt_tokens"):
                token_usage["prompt_tokens"] = usage.prompt_tokens
            if hasattr(usage, "completion_tokens"):
                token_usage["completion_tokens"] = usage.completion_tokens
            if hasattr(usage, "total_tokens"):
                token_usage["total_tokens"] = usage.total_tokens

        # Extract usage from generations if available
        elif hasattr(response, "generations") and response.generations:
            for gen in response.generations:
                if gen and hasattr(gen[0], "generation_info"):
                    gen_info = gen[0].generation_info or {}
                    if "usage" in gen_info:
                        token_usage = gen_info["usage"]
                        break

        # Update counts with lock to prevent race conditions
        with self._lock:
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)

            # Only update prompt tokens if we have them
            if prompt_tokens > 0:
                self.prompt_tokens += prompt_tokens
                self.total_tokens += prompt_tokens
                prompt_cost = get_anthropic_token_cost_for_model(
                    self.model_name, prompt_tokens, is_completion=False
                )
                self.total_cost += prompt_cost

            # Only update completion tokens if not already counted by on_llm_new_token
            if completion_tokens > 0 and completion_tokens > self.completion_tokens:
                additional_tokens = completion_tokens - self.completion_tokens
                self.completion_tokens = completion_tokens
                self.total_tokens += additional_tokens
                completion_cost = get_anthropic_token_cost_for_model(
                    self.model_name, additional_tokens, is_completion=True
                )
                self.total_cost += completion_cost

            self.successful_requests += 1

    def __copy__(self) -> "AnthropicCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self
# Create a context variable for our custom callback
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
    "anthropic_callback", default=None
)


@contextmanager
def get_anthropic_callback(
    model_name: Optional[str] = None,
) -> AnthropicCallbackHandler:
    """Get the Anthropic callback handler in a context manager,
    which conveniently exposes token and cost information.

    Args:
        model_name: Optional model name to use for cost calculation.

    Returns:
        AnthropicCallbackHandler: The Anthropic callback handler.

    Example:
        >>> with get_anthropic_callback("claude-3-sonnet") as cb:
        ...     # Use the callback handler
        ...     # cb.total_tokens, cb.total_cost will be available after
    """
    cb = AnthropicCallbackHandler(model_name)
    anthropic_callback_var.set(cb)
    yield cb
    anthropic_callback_var.set(None)
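Two properties of this module are worth spelling out. First, the cost table keys input prices by model name and output prices by the "-completion" suffix, with unknown models falling back to Claude 3 Sonnet pricing; a few worked values under that convention:

    get_anthropic_token_cost_for_model("claude-3-7-sonnet-20250219", 1000)
    # -> 0.003  (0.003 per 1K tokens * 1000 tokens, input price)
    get_anthropic_token_cost_for_model("anthropic/claude-3.7-sonnet", 1000, is_completion=True)
    # -> 0.015  (prefix stripped, "-completion" suffix appended, output price)
    get_anthropic_token_cost_for_model("claude-unknown", 1000)
    # -> 0.003  (not in the table, falls back to claude-3-sonnet input pricing)

Second, __copy__ and __deepcopy__ both return self, so even if a framework deep-copies the run config that carries the handler, every copy shares one set of counters:

    import copy
    cb = AnthropicCallbackHandler("claude-3-7-sonnet-20250219")
    assert copy.deepcopy(cb) is cb  # counters keep accumulating in one place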

uv.lock (910 changed lines)
File diff suppressed because it is too large