feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models

style(agent_utils.py): format imports and code for better readability
refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability
chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
This commit is contained in:
Ariel Frischer 2025-03-10 01:18:44 -07:00
parent d194868cff
commit ddd0e2ae2d
3 changed files with 1040 additions and 623 deletions

View File

@ -10,6 +10,9 @@ import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Sequence
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from openai import RateLimitError as OpenAIRateLimitError
@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
from ra_aid.prompts.reasoning_assist_prompt import (
REASONING_ASSIST_PROMPT_PLANNING,
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
REASONING_ASSIST_PROMPT_RESEARCH,
)
from ra_aid.prompts.research_prompts import (
RESEARCH_ONLY_PROMPT,
RESEARCH_PROMPT,
@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
)
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.key_snippet_repository import (
get_key_snippet_repository,
)
from ra_aid.database.repositories.human_input_repository import (
get_human_input_repository,
)
from ra_aid.database.repositories.research_note_repository import (
get_research_note_repository,
)
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@ -332,7 +345,9 @@ def create_agent(
if is_anthropic_claude(config):
logger.debug("Using create_react_agent to instantiate agent.")
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
return create_react_agent(
model, tools, interrupt_after=["tools"], **agent_kwargs
)
else:
logger.debug("Using CiaynAgent agent instance")
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
@ -343,7 +358,9 @@ def create_agent(
config = get_config_repository().get_all()
max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
return create_react_agent(
model, tools, interrupt_after=["tools"], **agent_kwargs
)
def run_research_agent(
@ -406,7 +423,9 @@ def run_research_agent(
recent_inputs = human_input_repository.get_recent(1)
if recent_inputs and len(recent_inputs) > 0:
last_human_input = recent_inputs[0].content
base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
base_task = (
f"<last human input>{last_human_input}</last human input>\n{base_task}"
)
except RuntimeError as e:
logger.error(f"Failed to access human input repository: {str(e)}")
# Continue without appending last human input
@ -416,7 +435,9 @@ def run_research_agent(
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
key_snippets = format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
)
related_files = get_related_files()
try:
@ -432,20 +453,22 @@ def run_research_agent(
human_interaction=hil,
web_research_enabled=get_config_repository().get("web_research_enabled", False),
)
# Get model info for reasoning assistance configuration
provider = get_config_repository().get("provider", "")
model_name = get_config_repository().get("model", "")
# Get model configuration to check for reasoning_assist_default
model_config = {}
provider_models = models_params.get(provider, {})
if provider_models and model_name in provider_models:
model_config = provider_models[model_name]
# Check if reasoning assist is explicitly enabled/disabled
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
disable_assistance = get_config_repository().get(
"disable_reasoning_assistance", False
)
if force_assistance:
reasoning_assist_enabled = True
elif disable_assistance:
@ -453,26 +476,31 @@ def run_research_agent(
else:
# Fall back to model default
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
expert_guidance = ""
# Get research note information for reasoning assistance
try:
research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
research_notes = format_research_notes_dict(
get_research_note_repository().get_notes_dict()
)
except Exception as e:
logger.warning(f"Failed to get research notes: {e}")
research_notes = ""
# If reasoning assist is enabled, make a one-off call to the expert model
if reasoning_assist_enabled:
try:
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
logger.info(
"Reasoning assist enabled for model %s, getting expert guidance",
model_name,
)
# Collect tool descriptions
tool_metadata = []
from ra_aid.tools.reflection import get_function_info as get_tool_info
for tool in tools:
try:
tool_info = get_tool_info(tool.func)
@ -481,13 +509,13 @@ def run_research_agent(
tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
except Exception as e:
logger.warning(f"Error getting tool info for {tool}: {e}")
# Format tool metadata
formatted_tool_metadata = "\n".join(tool_metadata)
# Initialize expert model
expert_model = initialize_expert_llm(provider, model_name)
# Format the reasoning assist prompt
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_RESEARCH.format(
current_date=current_date,
@ -500,62 +528,78 @@ def run_research_agent(
env_inv=get_env_inv(),
tool_metadata=formatted_tool_metadata,
)
# Show the reasoning assist query in a panel
console.print(
Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
Panel(
Markdown(
"Consulting with the reasoning model on the best research approach."
),
title="📝 Thinking about research strategy...",
border_style="yellow",
)
)
logger.debug("Invoking expert model for reasoning assist")
# Make the call to the expert model
response = expert_model.invoke(reasoning_assist_prompt)
# Check if the model supports think tags
supports_think_tag = model_config.get("supports_think_tag", False)
supports_thinking = model_config.get("supports_thinking", False)
# Get response content, handling if it's a list (for Claude thinking mode)
content = None
if hasattr(response, 'content'):
if hasattr(response, "content"):
content = response.content
else:
# Fallback if content attribute is missing
content = str(response)
# Process content based on its type
if isinstance(content, list):
# Handle structured thinking mode (e.g., Claude 3.7)
thinking_content = None
response_text = None
# Process each item in the list
for item in content:
if isinstance(item, dict):
# Extract thinking content
if item.get('type') == 'thinking' and 'thinking' in item:
thinking_content = item['thinking']
if item.get("type") == "thinking" and "thinking" in item:
thinking_content = item["thinking"]
logger.debug("Found structured thinking content")
# Extract response text
elif item.get('type') == 'text' and 'text' in item:
response_text = item['text']
elif item.get("type") == "text" and "text" in item:
response_text = item["text"]
logger.debug("Found structured response text")
# Display thinking content in a separate panel if available
if thinking_content and get_config_repository().get("show_thoughts", False):
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
console.print(
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
if thinking_content and get_config_repository().get(
"show_thoughts", False
):
logger.debug(
f"Displaying structured thinking content ({len(thinking_content)} chars)"
)
console.print(
Panel(
Markdown(thinking_content),
title="💭 Expert Thinking",
border_style="yellow",
)
)
# Use response_text if available, otherwise fall back to joining
if response_text:
content = response_text
else:
# Fallback: join list items if structured extraction failed
logger.debug("No structured response text found, joining list items")
logger.debug(
"No structured response text found, joining list items"
)
content = "\n".join(str(item) for item in content)
elif (supports_think_tag or supports_thinking):
elif supports_think_tag or supports_thinking:
# Process thinking content using the centralized function
content, _ = process_thinking_content(
content=content,
@ -563,22 +607,28 @@ def run_research_agent(
supports_thinking=supports_thinking,
panel_title="💭 Expert Thinking",
panel_style="yellow",
logger=logger
logger=logger,
)
# Display the expert guidance in a panel
console.print(
Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
Panel(
Markdown(content),
title="Research Strategy Guidance",
border_style="blue",
)
)
# Use the content as expert guidance
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
expert_guidance = (
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
)
logger.info("Received expert guidance for research")
except Exception as e:
logger.error("Error getting expert guidance for research: %s", e)
expert_guidance = ""
agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@ -588,7 +638,7 @@ def run_research_agent(
if get_config_repository().get("web_research_enabled")
else ""
)
# Prepare expert guidance section if expert guidance is available
expert_guidance_section = ""
if expert_guidance:
@ -600,7 +650,7 @@ def run_research_agent(
# We get research notes earlier for reasoning assistance
# Get environment inventory information
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
current_date=current_date,
working_directory=working_directory,
@ -643,9 +693,7 @@ def run_research_agent(
if agent is not None:
logger.debug("Research agent created successfully")
none_or_fallback_handler = init_fallback_handler(agent, tools)
_result = run_agent_with_retry(
agent, prompt, none_or_fallback_handler
)
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
if _result:
# Log research completion
log_work_event(f"Completed research phase for: {base_task_or_query}")
@ -731,7 +779,9 @@ def run_web_research_agent(
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
key_snippets = format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
)
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
@ -741,7 +791,7 @@ def run_web_research_agent(
working_directory = os.getcwd()
# Get environment inventory information
prompt = WEB_RESEARCH_PROMPT.format(
current_date=current_date,
working_directory=working_directory,
@ -771,9 +821,7 @@ def run_web_research_agent(
logger.debug("Web research agent completed successfully")
none_or_fallback_handler = init_fallback_handler(agent, tools)
_result = run_agent_with_retry(
agent, prompt, none_or_fallback_handler
)
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
if _result:
# Log web research completion
log_work_event(f"Completed web research phase for: {query}")
@ -835,17 +883,19 @@ def run_planning_agent(
provider = get_config_repository().get("provider", "")
model_name = get_config_repository().get("model", "")
logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)
# Get model configuration to check for reasoning_assist_default
model_config = {}
provider_models = models_params.get(provider, {})
if provider_models and model_name in provider_models:
model_config = provider_models[model_name]
# Check if reasoning assist is explicitly enabled/disabled
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
disable_assistance = get_config_repository().get(
"disable_reasoning_assistance", False
)
if force_assistance:
reasoning_assist_enabled = True
elif disable_assistance:
@ -853,27 +903,29 @@ def run_planning_agent(
else:
# Fall back to model default
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
# Get all the context information (used both for normal planning and reasoning assist)
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
working_directory = os.getcwd()
# Make sure key_facts is defined before using it
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
# Make sure key_snippets is defined before using it
try:
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
key_snippets = format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
)
except RuntimeError as e:
logger.error(f"Failed to access key snippet repository: {str(e)}")
key_snippets = ""
# Get formatted research notes using repository
try:
repository = get_research_note_repository()
@ -882,28 +934,31 @@ def run_planning_agent(
except RuntimeError as e:
logger.error(f"Failed to access research note repository: {str(e)}")
formatted_research_notes = ""
# Get related files
related_files = "\n".join(get_related_files())
# Get environment inventory information
env_inv = get_env_inv()
# Display the planning stage header before any reasoning assistance
print_stage_header("Planning Stage")
# Initialize expert guidance section
expert_guidance = ""
# If reasoning assist is enabled, make a one-off call to the expert model
if reasoning_assist_enabled:
try:
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
logger.info(
"Reasoning assist enabled for model %s, getting expert guidance",
model_name,
)
# Collect tool descriptions
tool_metadata = []
from ra_aid.tools.reflection import get_function_info as get_tool_info
for tool in tools:
try:
tool_info = get_tool_info(tool.func)
@ -912,13 +967,13 @@ def run_planning_agent(
tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
except Exception as e:
logger.warning(f"Error getting tool info for {tool}: {e}")
# Format tool metadata
formatted_tool_metadata = "\n".join(tool_metadata)
# Initialize expert model
expert_model = initialize_expert_llm(provider, model_name)
# Format the reasoning assist prompt
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_PLANNING.format(
current_date=current_date,
@ -931,62 +986,78 @@ def run_planning_agent(
env_inv=env_inv,
tool_metadata=formatted_tool_metadata,
)
# Show the reasoning assist query in a panel
console.print(
Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
Panel(
Markdown(
"Consulting with the reasoning model on the best way to do this."
),
title="📝 Thinking about the plan...",
border_style="yellow",
)
)
logger.debug("Invoking expert model for reasoning assist")
# Make the call to the expert model
response = expert_model.invoke(reasoning_assist_prompt)
# Check if the model supports think tags
supports_think_tag = model_config.get("supports_think_tag", False)
supports_thinking = model_config.get("supports_thinking", False)
# Get response content, handling if it's a list (for Claude thinking mode)
content = None
if hasattr(response, 'content'):
if hasattr(response, "content"):
content = response.content
else:
# Fallback if content attribute is missing
content = str(response)
# Process content based on its type
if isinstance(content, list):
# Handle structured thinking mode (e.g., Claude 3.7)
thinking_content = None
response_text = None
# Process each item in the list
for item in content:
if isinstance(item, dict):
# Extract thinking content
if item.get('type') == 'thinking' and 'thinking' in item:
thinking_content = item['thinking']
if item.get("type") == "thinking" and "thinking" in item:
thinking_content = item["thinking"]
logger.debug("Found structured thinking content")
# Extract response text
elif item.get('type') == 'text' and 'text' in item:
response_text = item['text']
elif item.get("type") == "text" and "text" in item:
response_text = item["text"]
logger.debug("Found structured response text")
# Display thinking content in a separate panel if available
if thinking_content and get_config_repository().get("show_thoughts", False):
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
console.print(
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
if thinking_content and get_config_repository().get(
"show_thoughts", False
):
logger.debug(
f"Displaying structured thinking content ({len(thinking_content)} chars)"
)
console.print(
Panel(
Markdown(thinking_content),
title="💭 Expert Thinking",
border_style="yellow",
)
)
# Use response_text if available, otherwise fall back to joining
if response_text:
content = response_text
else:
# Fallback: join list items if structured extraction failed
logger.debug("No structured response text found, joining list items")
logger.debug(
"No structured response text found, joining list items"
)
content = "\n".join(str(item) for item in content)
elif (supports_think_tag or supports_thinking):
elif supports_think_tag or supports_thinking:
# Process thinking content using the centralized function
content, _ = process_thinking_content(
content=content,
@ -994,24 +1065,28 @@ def run_planning_agent(
supports_thinking=supports_thinking,
panel_title="💭 Expert Thinking",
panel_style="yellow",
logger=logger
logger=logger,
)
# Display the expert guidance in a panel
console.print(
Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
Panel(
Markdown(content), title="Reasoning Guidance", border_style="blue"
)
)
# Use the content as expert guidance
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
expert_guidance = (
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
)
logger.info("Received expert guidance for planning")
except Exception as e:
logger.error("Error getting expert guidance for planning: %s", e)
expert_guidance = ""
agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
web_research_section = (
@ -1019,7 +1094,7 @@ def run_planning_agent(
if get_config_repository().get("web_research_enabled", False)
else ""
)
# Prepare expert guidance section if expert guidance is available
expert_guidance_section = ""
if expert_guidance:
@ -1050,7 +1125,9 @@ def run_planning_agent(
)
config_values = get_config_repository().get_all()
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
recursion_limit = get_config_repository().get(
"recursion_limit", DEFAULT_RECURSION_LIMIT
)
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
@ -1060,9 +1137,7 @@ def run_planning_agent(
try:
logger.debug("Planning agent completed successfully")
none_or_fallback_handler = init_fallback_handler(agent, tools)
_result = run_agent_with_retry(
agent, planning_prompt, none_or_fallback_handler
)
_result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
if _result:
# Log planning completion
log_work_event(f"Completed planning phase for: {base_task}")
@ -1135,7 +1210,7 @@ def run_task_implementation_agent(
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
key_facts = ""
# Get formatted research notes using repository
try:
repository = get_research_note_repository()
@ -1144,7 +1219,7 @@ def run_task_implementation_agent(
except RuntimeError as e:
logger.error(f"Failed to access research note repository: {str(e)}")
formatted_research_notes = ""
# Get latest project info
try:
project_info = get_project_info(".")
@ -1152,24 +1227,26 @@ def run_task_implementation_agent(
except Exception as e:
logger.warning("Failed to get project info: %s", str(e))
formatted_project_info = "Project info unavailable"
# Get environment inventory information
env_inv = get_env_inv()
# Get model configuration to check for reasoning_assist_default
provider = get_config_repository().get("provider", "")
model_name = get_config_repository().get("model", "")
logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)
model_config = {}
provider_models = models_params.get(provider, {})
if provider_models and model_name in provider_models:
model_config = provider_models[model_name]
# Check if reasoning assist is explicitly enabled/disabled
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
disable_assistance = get_config_repository().get(
"disable_reasoning_assistance", False
)
if force_assistance:
reasoning_assist_enabled = True
elif disable_assistance:
@ -1177,71 +1254,84 @@ def run_task_implementation_agent(
else:
# Fall back to model default
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
# Initialize implementation guidance section
implementation_guidance_section = ""
# If reasoning assist is enabled, make a one-off call to the expert model
if reasoning_assist_enabled:
try:
logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
logger.info(
"Reasoning assist enabled for model %s, getting implementation guidance",
model_name,
)
# Collect tool descriptions
tool_metadata = []
from ra_aid.tools.reflection import get_function_info as get_tool_info
for tool in tools:
try:
tool_info = get_tool_info(tool.func)
name = tool.func.__name__
description = inspect.getdoc(tool.func)
tool_metadata.append(f"Tool: {name}\\nDescription: {description}\\n")
tool_metadata.append(
f"Tool: {name}\\nDescription: {description}\\n"
)
except Exception as e:
logger.warning(f"Error getting tool info for {tool}: {e}")
# Format tool metadata
formatted_tool_metadata = "\\n".join(tool_metadata)
# Initialize expert model
expert_model = initialize_expert_llm(provider, model_name)
# Format the reasoning assist prompt for implementation
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_IMPLEMENTATION.format(
current_date=current_date,
working_directory=working_directory,
task=task,
key_facts=key_facts,
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
key_snippets=format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
),
research_notes=formatted_research_notes,
related_files="\\n".join(related_files),
env_inv=env_inv,
tool_metadata=formatted_tool_metadata,
)
# Show the reasoning assist query in a panel
console.print(
Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
Panel(
Markdown(
"Consulting with the reasoning model on the best implementation approach."
),
title="📝 Thinking about implementation...",
border_style="yellow",
)
)
logger.debug("Invoking expert model for implementation reasoning assist")
# Make the call to the expert model
response = expert_model.invoke(reasoning_assist_prompt)
# Check if the model supports think tags
supports_think_tag = model_config.get("supports_think_tag", False)
supports_thinking = model_config.get("supports_thinking", False)
# Process response content
content = None
if hasattr(response, 'content'):
if hasattr(response, "content"):
content = response.content
else:
# Fallback if content attribute is missing
content = str(response)
# Process the response content using the centralized function
content, extracted_thinking = process_thinking_content(
content=content,
@ -1249,24 +1339,28 @@ def run_task_implementation_agent(
supports_thinking=supports_thinking,
panel_title="💭 Implementation Thinking",
panel_style="yellow",
logger=logger
logger=logger,
)
# Display the implementation guidance in a panel
console.print(
Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
Panel(
Markdown(content),
title="Implementation Guidance",
border_style="blue",
)
)
# Format the implementation guidance section for the prompt
implementation_guidance_section = f"""<implementation guidance>
{content}
</implementation guidance>"""
logger.info("Received implementation guidance")
except Exception as e:
logger.error("Error getting implementation guidance: %s", e)
implementation_guidance_section = ""
prompt = IMPLEMENTATION_PROMPT.format(
current_date=current_date,
working_directory=working_directory,
@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
plan=plan,
related_files=related_files,
key_facts=key_facts,
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
key_snippets=format_key_snippets_dict(
get_key_snippet_repository().get_snippets_dict()
),
research_notes=formatted_research_notes,
work_log=get_work_log_repository().format_work_log(),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
)
config_values = get_config_repository().get_all()
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
recursion_limit = get_config_repository().get(
"recursion_limit", DEFAULT_RECURSION_LIMIT
)
run_config = {
"configurable": {"thread_id": thread_id},
"recursion_limit": recursion_limit,
@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
try:
logger.debug("Implementation agent completed successfully")
none_or_fallback_handler = init_fallback_handler(agent, tools)
_result = run_agent_with_retry(
agent, prompt, none_or_fallback_handler
)
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
if _result:
# Log task implementation completion
log_work_event(f"Completed implementation of task: {task}")
@ -1380,27 +1476,37 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
# 1. Check if this is a ValueError with 429 code or rate limit phrases
if isinstance(e, ValueError):
error_str = str(e).lower()
rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
rate_limit_phrases = [
"429",
"rate limit",
"too many requests",
"quota exceeded",
]
if "code" not in error_str and not any(
phrase in error_str for phrase in rate_limit_phrases
):
raise e
# 2. Check for status_code or http_status attribute equal to 429
if hasattr(e, 'status_code') and e.status_code == 429:
if hasattr(e, "status_code") and e.status_code == 429:
pass # This is a rate limit error, continue with retry logic
elif hasattr(e, 'http_status') and e.http_status == 429:
elif hasattr(e, "http_status") and e.http_status == 429:
pass # This is a rate limit error, continue with retry logic
# 3. Check for rate limit phrases in error message
elif isinstance(e, Exception) and not isinstance(e, ValueError):
error_str = str(e).lower()
if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
if not any(
phrase in error_str
for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
) and not ("rate" in error_str and "limit" in error_str):
# This doesn't look like a rate limit error, but we'll still retry other API errors
pass
# Apply common retry logic for all identified errors
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
delay = base_delay * (2**attempt)
print_error(
@ -1457,55 +1563,78 @@ def _handle_fallback_response(
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
"""
Streams agent output while handling completion and interruption.
For each chunk, it logs the output, calls check_interrupt(), prints agent output,
and then checks if is_completed() or should_exit() are true. If so, it resets completion
flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks),
the function retrieves the agent's state. If the state indicates further steps (i.e. state.next is non-empty),
it resumes execution via agent.invoke(None, config); otherwise, it exits the loop.
This function adheres to the latest LangGraph best practices (as of March 2025) for handling
human-in-the-loop interruptions using interrupt_after=["tools"].
"""
config = get_config_repository().get_all()
stream_config = config.copy()
cb = None
if is_anthropic_claude(config):
model_name = config.get("model", "")
full_model_name = model_name
cb = AnthropicCallbackHandler(full_model_name)
if "callbacks" not in stream_config:
stream_config["callbacks"] = []
stream_config["callbacks"].append(cb)
while True:
# Process each chunk from the agent stream.
for chunk in agent.stream({"messages": msg_list}, config):
for chunk in agent.stream({"messages": msg_list}, stream_config):
logger.debug("Agent output: %s", chunk)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if is_completed() or should_exit():
reset_completion_flags()
return True # Exit immediately when finished or signaled to exit.
if cb:
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
return True
logger.debug("Stream iteration ended; checking agent state for continuation.")
# Prepare state configuration, ensuring 'configurable' is present.
state_config = get_config_repository().get_all().copy()
if "configurable" not in state_config:
logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
logger.debug(
"Key 'configurable' not found in config; adding it as an empty dict."
)
state_config["configurable"] = {}
logger.debug("Using state_config for agent.get_state(): %s", state_config)
try:
state = agent.get_state(state_config)
logger.debug("Agent state retrieved: %s", state)
except Exception as e:
logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
logger.error(
"Error retrieving agent state with state_config %s: %s", state_config, e
)
raise
# If the state indicates that further steps remain (i.e. state.next is non-empty),
# then resume execution by invoking the agent with no new input.
if state.next:
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
agent.invoke(None, config)
logger.debug(
"State indicates continuation (state.next: %s); resuming execution.",
state.next,
)
agent.invoke(None, stream_config)
continue
else:
logger.debug("No continuation indicated in state; exiting stream loop.")
break
if cb:
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
return True
def run_agent_with_retry(
agent: RAgents,
prompt: str,
@ -1517,7 +1646,9 @@ def run_agent_with_retry(
max_retries = 20
base_delay = 1
test_attempts = 0
_max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
_max_test_retries = get_config_repository().get(
"max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
)
auto_test = get_config_repository().get("auto_test", False)
original_prompt = prompt
msg_list = [HumanMessage(content=prompt)]

View File

@ -0,0 +1,270 @@
"""Custom callback handlers for tracking token usage and costs."""
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler
# Define cost per 1K tokens for various models
ANTHROPIC_MODEL_COSTS = {
# Claude 3.7 Sonnet input
"claude-3-7-sonnet-20250219": 0.003,
"anthropic/claude-3.7-sonnet": 0.003,
"claude-3.7-sonnet": 0.003,
# Claude 3.7 Sonnet output
"claude-3-7-sonnet-20250219-completion": 0.015,
"anthropic/claude-3.7-sonnet-completion": 0.015,
"claude-3.7-sonnet-completion": 0.015,
# Claude 3 Opus input
"claude-3-opus-20240229": 0.015,
"anthropic/claude-3-opus": 0.015,
"claude-3-opus": 0.015,
# Claude 3 Opus output
"claude-3-opus-20240229-completion": 0.075,
"anthropic/claude-3-opus-completion": 0.075,
"claude-3-opus-completion": 0.075,
# Claude 3 Sonnet input
"claude-3-sonnet-20240229": 0.003,
"anthropic/claude-3-sonnet": 0.003,
"claude-3-sonnet": 0.003,
# Claude 3 Sonnet output
"claude-3-sonnet-20240229-completion": 0.015,
"anthropic/claude-3-sonnet-completion": 0.015,
"claude-3-sonnet-completion": 0.015,
# Claude 3 Haiku input
"claude-3-haiku-20240307": 0.00025,
"anthropic/claude-3-haiku": 0.00025,
"claude-3-haiku": 0.00025,
# Claude 3 Haiku output
"claude-3-haiku-20240307-completion": 0.00125,
"anthropic/claude-3-haiku-completion": 0.00125,
"claude-3-haiku-completion": 0.00125,
# Claude 2 input
"claude-2": 0.008,
"claude-2.0": 0.008,
"claude-2.1": 0.008,
# Claude 2 output
"claude-2-completion": 0.024,
"claude-2.0-completion": 0.024,
"claude-2.1-completion": 0.024,
# Claude Instant input
"claude-instant-1": 0.0016,
"claude-instant-1.2": 0.0016,
# Claude Instant output
"claude-instant-1-completion": 0.0055,
"claude-instant-1.2-completion": 0.0055,
}
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
"""
Standardize the model name to a format that can be used for cost calculation.
Args:
model_name: Model name to standardize.
is_completion: Whether the model is used for completion or not.
Returns:
Standardized model name.
"""
if not model_name:
model_name = "claude-3-sonnet"
model_name = model_name.lower()
# Handle OpenRouter prefixes
if model_name.startswith("anthropic/"):
model_name = model_name[len("anthropic/") :]
# Add completion suffix if needed
if is_completion and not model_name.endswith("-completion"):
model_name = model_name + "-completion"
return model_name
def get_anthropic_token_cost_for_model(
model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
"""
Get the cost in USD for a given model and number of tokens.
Args:
model_name: Name of the model
num_tokens: Number of tokens.
is_completion: Whether the model is used for completion or not.
Returns:
Cost in USD.
"""
model_name = standardize_model_name(model_name, is_completion)
if model_name not in ANTHROPIC_MODEL_COSTS:
# Default to Claude 3 Sonnet pricing if model not found
model_name = (
"claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
)
cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
total_cost = cost_per_1k * (num_tokens / 1000)
return total_cost
class AnthropicCallbackHandler(BaseCallbackHandler):
"""Callback Handler that tracks Anthropic token usage and costs."""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
successful_requests: int = 0
total_cost: float = 0.0
model_name: str = "claude-3-sonnet" # Default model
def __init__(self, model_name: Optional[str] = None) -> None:
super().__init__()
self._lock = threading.Lock()
if model_name:
self.model_name = model_name
# Default costs for Claude 3.7 Sonnet
self.input_cost_per_token = 0.003 / 1000 # $3/M input tokens
self.output_cost_per_token = 0.015 / 1000 # $15/M output tokens
def __repr__(self) -> str:
return (
f"Tokens Used: {self.total_tokens}\n"
f"\tPrompt Tokens: {self.prompt_tokens}\n"
f"\tCompletion Tokens: {self.completion_tokens}\n"
f"Successful Requests: {self.successful_requests}\n"
f"Total Cost (USD): ${self.total_cost:.6f}"
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Record the model name if available."""
if "name" in serialized:
self.model_name = serialized["name"]
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Count tokens as they're generated."""
with self._lock:
self.completion_tokens += 1
self.total_tokens += 1
token_cost = get_anthropic_token_cost_for_model(
self.model_name, 1, is_completion=True
)
self.total_cost += token_cost
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
"""Collect token usage from response."""
token_usage = {}
# Try to extract token usage from response
if hasattr(response, "llm_output") and response.llm_output:
llm_output = response.llm_output
if "token_usage" in llm_output:
token_usage = llm_output["token_usage"]
elif "usage" in llm_output:
usage = llm_output["usage"]
# Handle Anthropic's specific usage format
if "input_tokens" in usage:
token_usage["prompt_tokens"] = usage["input_tokens"]
if "output_tokens" in usage:
token_usage["completion_tokens"] = usage["output_tokens"]
# Extract model name if available
if "model_name" in llm_output:
self.model_name = llm_output["model_name"]
# Try to get usage from response.usage
elif hasattr(response, "usage"):
usage = response.usage
if hasattr(usage, "prompt_tokens"):
token_usage["prompt_tokens"] = usage.prompt_tokens
if hasattr(usage, "completion_tokens"):
token_usage["completion_tokens"] = usage.completion_tokens
if hasattr(usage, "total_tokens"):
token_usage["total_tokens"] = usage.total_tokens
# Extract usage from generations if available
elif hasattr(response, "generations") and response.generations:
for gen in response.generations:
if gen and hasattr(gen[0], "generation_info"):
gen_info = gen[0].generation_info or {}
if "usage" in gen_info:
token_usage = gen_info["usage"]
break
# Update counts with lock to prevent race conditions
with self._lock:
prompt_tokens = token_usage.get("prompt_tokens", 0)
completion_tokens = token_usage.get("completion_tokens", 0)
# Only update prompt tokens if we have them
if prompt_tokens > 0:
self.prompt_tokens += prompt_tokens
self.total_tokens += prompt_tokens
prompt_cost = get_anthropic_token_cost_for_model(
self.model_name, prompt_tokens, is_completion=False
)
self.total_cost += prompt_cost
# Only update completion tokens if not already counted by on_llm_new_token
if completion_tokens > 0 and completion_tokens > self.completion_tokens:
additional_tokens = completion_tokens - self.completion_tokens
self.completion_tokens = completion_tokens
self.total_tokens += additional_tokens
completion_cost = get_anthropic_token_cost_for_model(
self.model_name, additional_tokens, is_completion=True
)
self.total_cost += completion_cost
self.successful_requests += 1
def __copy__(self) -> "AnthropicCallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
"""Return a deep copy of the callback handler."""
return self
# Create a context variable for our custom callback
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
"anthropic_callback", default=None
)
@contextmanager
def get_anthropic_callback(
model_name: Optional[str] = None,
) -> AnthropicCallbackHandler:
"""Get the Anthropic callback handler in a context manager.
which conveniently exposes token and cost information.
Args:
model_name: Optional model name to use for cost calculation.
Returns:
AnthropicCallbackHandler: The Anthropic callback handler.
Example:
>>> with get_anthropic_callback("claude-3-sonnet") as cb:
... # Use the callback handler
... # cb.total_tokens, cb.total_cost will be available after
"""
cb = AnthropicCallbackHandler(model_name)
anthropic_callback_var.set(cb)
yield cb
anthropic_callback_var.set(None)

910
uv.lock

File diff suppressed because it is too large Load Diff