feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models
style(agent_utils.py): format imports and code for better readability refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
This commit is contained in:
parent
d194868cff
commit
ddd0e2ae2d
|
|
@ -10,6 +10,9 @@ import uuid
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||
|
||||
|
||||
import litellm
|
||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||
from openai import RateLimitError as OpenAIRateLimitError
|
||||
|
|
@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
|
|||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
||||
from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
|
||||
from ra_aid.prompts.reasoning_assist_prompt import (
|
||||
REASONING_ASSIST_PROMPT_PLANNING,
|
||||
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
||||
REASONING_ASSIST_PROMPT_RESEARCH,
|
||||
)
|
||||
from ra_aid.prompts.research_prompts import (
|
||||
RESEARCH_ONLY_PROMPT,
|
||||
RESEARCH_PROMPT,
|
||||
|
|
@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
|
|||
)
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import (
|
||||
get_key_snippet_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
get_human_input_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.research_note_repository import (
|
||||
get_research_note_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -332,7 +345,9 @@ def create_agent(
|
|||
if is_anthropic_claude(config):
|
||||
logger.debug("Using create_react_agent to instantiate agent.")
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
else:
|
||||
logger.debug("Using CiaynAgent agent instance")
|
||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
||||
|
|
@ -343,7 +358,9 @@ def create_agent(
|
|||
config = get_config_repository().get_all()
|
||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
||||
return create_react_agent(
|
||||
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||
)
|
||||
|
||||
|
||||
def run_research_agent(
|
||||
|
|
@ -406,7 +423,9 @@ def run_research_agent(
|
|||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
last_human_input = recent_inputs[0].content
|
||||
base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
||||
base_task = (
|
||||
f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access human input repository: {str(e)}")
|
||||
# Continue without appending last human input
|
||||
|
|
@ -416,7 +435,9 @@ def run_research_agent(
|
|||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
related_files = get_related_files()
|
||||
|
||||
try:
|
||||
|
|
@ -432,20 +453,22 @@ def run_research_agent(
|
|||
human_interaction=hil,
|
||||
web_research_enabled=get_config_repository().get("web_research_enabled", False),
|
||||
)
|
||||
|
||||
|
||||
# Get model info for reasoning assistance configuration
|
||||
provider = get_config_repository().get("provider", "")
|
||||
model_name = get_config_repository().get("model", "")
|
||||
|
||||
|
||||
# Get model configuration to check for reasoning_assist_default
|
||||
model_config = {}
|
||||
provider_models = models_params.get(provider, {})
|
||||
if provider_models and model_name in provider_models:
|
||||
model_config = provider_models[model_name]
|
||||
|
||||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
elif disable_assistance:
|
||||
|
|
@ -453,26 +476,31 @@ def run_research_agent(
|
|||
else:
|
||||
# Fall back to model default
|
||||
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
|
||||
|
||||
|
||||
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
|
||||
expert_guidance = ""
|
||||
|
||||
|
||||
# Get research note information for reasoning assistance
|
||||
try:
|
||||
research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
|
||||
research_notes = format_research_notes_dict(
|
||||
get_research_note_repository().get_notes_dict()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get research notes: {e}")
|
||||
research_notes = ""
|
||||
|
||||
|
||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
||||
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
from ra_aid.tools.reflection import get_function_info as get_tool_info
|
||||
|
||||
|
||||
for tool in tools:
|
||||
try:
|
||||
tool_info = get_tool_info(tool.func)
|
||||
|
|
@ -481,13 +509,13 @@ def run_research_agent(
|
|||
tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting tool info for {tool}: {e}")
|
||||
|
||||
|
||||
# Format tool metadata
|
||||
formatted_tool_metadata = "\n".join(tool_metadata)
|
||||
|
||||
|
||||
# Initialize expert model
|
||||
expert_model = initialize_expert_llm(provider, model_name)
|
||||
|
||||
|
||||
# Format the reasoning assist prompt
|
||||
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_RESEARCH.format(
|
||||
current_date=current_date,
|
||||
|
|
@ -500,62 +528,78 @@ def run_research_agent(
|
|||
env_inv=get_env_inv(),
|
||||
tool_metadata=formatted_tool_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best research approach."
|
||||
),
|
||||
title="📝 Thinking about research strategy...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.debug("Invoking expert model for reasoning assist")
|
||||
# Make the call to the expert model
|
||||
response = expert_model.invoke(reasoning_assist_prompt)
|
||||
|
||||
|
||||
# Check if the model supports think tags
|
||||
supports_think_tag = model_config.get("supports_think_tag", False)
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
|
||||
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
content = str(response)
|
||||
|
||||
|
||||
# Process content based on its type
|
||||
if isinstance(content, list):
|
||||
# Handle structured thinking mode (e.g., Claude 3.7)
|
||||
thinking_content = None
|
||||
response_text = None
|
||||
|
||||
|
||||
# Process each item in the list
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
# Extract thinking content
|
||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
||||
thinking_content = item['thinking']
|
||||
if item.get("type") == "thinking" and "thinking" in item:
|
||||
thinking_content = item["thinking"]
|
||||
logger.debug("Found structured thinking content")
|
||||
# Extract response text
|
||||
elif item.get('type') == 'text' and 'text' in item:
|
||||
response_text = item['text']
|
||||
elif item.get("type") == "text" and "text" in item:
|
||||
response_text = item["text"]
|
||||
logger.debug("Found structured response text")
|
||||
|
||||
|
||||
# Display thinking content in a separate panel if available
|
||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
||||
console.print(
|
||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
||||
if thinking_content and get_config_repository().get(
|
||||
"show_thoughts", False
|
||||
):
|
||||
logger.debug(
|
||||
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(thinking_content),
|
||||
title="💭 Expert Thinking",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Use response_text if available, otherwise fall back to joining
|
||||
if response_text:
|
||||
content = response_text
|
||||
else:
|
||||
# Fallback: join list items if structured extraction failed
|
||||
logger.debug("No structured response text found, joining list items")
|
||||
logger.debug(
|
||||
"No structured response text found, joining list items"
|
||||
)
|
||||
content = "\n".join(str(item) for item in content)
|
||||
elif (supports_think_tag or supports_thinking):
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content=content,
|
||||
|
|
@ -563,22 +607,28 @@ def run_research_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Expert Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
# Display the expert guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content),
|
||||
title="Research Strategy Guidance",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Use the content as expert guidance
|
||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
||||
|
||||
expert_guidance = (
|
||||
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
||||
)
|
||||
|
||||
logger.info("Received expert guidance for research")
|
||||
except Exception as e:
|
||||
logger.error("Error getting expert guidance for research: %s", e)
|
||||
expert_guidance = ""
|
||||
|
||||
|
||||
agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
|
||||
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
|
|
@ -588,7 +638,7 @@ def run_research_agent(
|
|||
if get_config_repository().get("web_research_enabled")
|
||||
else ""
|
||||
)
|
||||
|
||||
|
||||
# Prepare expert guidance section if expert guidance is available
|
||||
expert_guidance_section = ""
|
||||
if expert_guidance:
|
||||
|
|
@ -600,7 +650,7 @@ def run_research_agent(
|
|||
# We get research notes earlier for reasoning assistance
|
||||
|
||||
# Get environment inventory information
|
||||
|
||||
|
||||
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -643,9 +693,7 @@ def run_research_agent(
|
|||
if agent is not None:
|
||||
logger.debug("Research agent created successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log research completion
|
||||
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
||||
|
|
@ -731,7 +779,9 @@ def run_web_research_agent(
|
|||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
|
@ -741,7 +791,7 @@ def run_web_research_agent(
|
|||
working_directory = os.getcwd()
|
||||
|
||||
# Get environment inventory information
|
||||
|
||||
|
||||
prompt = WEB_RESEARCH_PROMPT.format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -771,9 +821,7 @@ def run_web_research_agent(
|
|||
|
||||
logger.debug("Web research agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log web research completion
|
||||
log_work_event(f"Completed web research phase for: {query}")
|
||||
|
|
@ -835,17 +883,19 @@ def run_planning_agent(
|
|||
provider = get_config_repository().get("provider", "")
|
||||
model_name = get_config_repository().get("model", "")
|
||||
logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)
|
||||
|
||||
|
||||
# Get model configuration to check for reasoning_assist_default
|
||||
model_config = {}
|
||||
provider_models = models_params.get(provider, {})
|
||||
if provider_models and model_name in provider_models:
|
||||
model_config = provider_models[model_name]
|
||||
|
||||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
elif disable_assistance:
|
||||
|
|
@ -853,27 +903,29 @@ def run_planning_agent(
|
|||
else:
|
||||
# Fall back to model default
|
||||
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
|
||||
|
||||
|
||||
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
|
||||
|
||||
|
||||
# Get all the context information (used both for normal planning and reasoning assist)
|
||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
working_directory = os.getcwd()
|
||||
|
||||
|
||||
# Make sure key_facts is defined before using it
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
|
||||
# Make sure key_snippets is defined before using it
|
||||
try:
|
||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
||||
key_snippets = format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||
key_snippets = ""
|
||||
|
||||
|
||||
# Get formatted research notes using repository
|
||||
try:
|
||||
repository = get_research_note_repository()
|
||||
|
|
@ -882,28 +934,31 @@ def run_planning_agent(
|
|||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access research note repository: {str(e)}")
|
||||
formatted_research_notes = ""
|
||||
|
||||
|
||||
# Get related files
|
||||
related_files = "\n".join(get_related_files())
|
||||
|
||||
|
||||
# Get environment inventory information
|
||||
env_inv = get_env_inv()
|
||||
|
||||
|
||||
# Display the planning stage header before any reasoning assistance
|
||||
print_stage_header("Planning Stage")
|
||||
|
||||
|
||||
# Initialize expert guidance section
|
||||
expert_guidance = ""
|
||||
|
||||
|
||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
||||
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
from ra_aid.tools.reflection import get_function_info as get_tool_info
|
||||
|
||||
|
||||
for tool in tools:
|
||||
try:
|
||||
tool_info = get_tool_info(tool.func)
|
||||
|
|
@ -912,13 +967,13 @@ def run_planning_agent(
|
|||
tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting tool info for {tool}: {e}")
|
||||
|
||||
|
||||
# Format tool metadata
|
||||
formatted_tool_metadata = "\n".join(tool_metadata)
|
||||
|
||||
|
||||
# Initialize expert model
|
||||
expert_model = initialize_expert_llm(provider, model_name)
|
||||
|
||||
|
||||
# Format the reasoning assist prompt
|
||||
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_PLANNING.format(
|
||||
current_date=current_date,
|
||||
|
|
@ -931,62 +986,78 @@ def run_planning_agent(
|
|||
env_inv=env_inv,
|
||||
tool_metadata=formatted_tool_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best way to do this."
|
||||
),
|
||||
title="📝 Thinking about the plan...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.debug("Invoking expert model for reasoning assist")
|
||||
# Make the call to the expert model
|
||||
response = expert_model.invoke(reasoning_assist_prompt)
|
||||
|
||||
|
||||
# Check if the model supports think tags
|
||||
supports_think_tag = model_config.get("supports_think_tag", False)
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
|
||||
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
content = str(response)
|
||||
|
||||
|
||||
# Process content based on its type
|
||||
if isinstance(content, list):
|
||||
# Handle structured thinking mode (e.g., Claude 3.7)
|
||||
thinking_content = None
|
||||
response_text = None
|
||||
|
||||
|
||||
# Process each item in the list
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
# Extract thinking content
|
||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
||||
thinking_content = item['thinking']
|
||||
if item.get("type") == "thinking" and "thinking" in item:
|
||||
thinking_content = item["thinking"]
|
||||
logger.debug("Found structured thinking content")
|
||||
# Extract response text
|
||||
elif item.get('type') == 'text' and 'text' in item:
|
||||
response_text = item['text']
|
||||
elif item.get("type") == "text" and "text" in item:
|
||||
response_text = item["text"]
|
||||
logger.debug("Found structured response text")
|
||||
|
||||
|
||||
# Display thinking content in a separate panel if available
|
||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
||||
console.print(
|
||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
||||
if thinking_content and get_config_repository().get(
|
||||
"show_thoughts", False
|
||||
):
|
||||
logger.debug(
|
||||
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(thinking_content),
|
||||
title="💭 Expert Thinking",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Use response_text if available, otherwise fall back to joining
|
||||
if response_text:
|
||||
content = response_text
|
||||
else:
|
||||
# Fallback: join list items if structured extraction failed
|
||||
logger.debug("No structured response text found, joining list items")
|
||||
logger.debug(
|
||||
"No structured response text found, joining list items"
|
||||
)
|
||||
content = "\n".join(str(item) for item in content)
|
||||
elif (supports_think_tag or supports_thinking):
|
||||
elif supports_think_tag or supports_thinking:
|
||||
# Process thinking content using the centralized function
|
||||
content, _ = process_thinking_content(
|
||||
content=content,
|
||||
|
|
@ -994,24 +1065,28 @@ def run_planning_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Expert Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
# Display the expert guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content), title="Reasoning Guidance", border_style="blue"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Use the content as expert guidance
|
||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
||||
|
||||
expert_guidance = (
|
||||
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
||||
)
|
||||
|
||||
logger.info("Received expert guidance for planning")
|
||||
except Exception as e:
|
||||
logger.error("Error getting expert guidance for planning: %s", e)
|
||||
expert_guidance = ""
|
||||
|
||||
|
||||
agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
|
||||
|
||||
|
||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
|
||||
web_research_section = (
|
||||
|
|
@ -1019,7 +1094,7 @@ def run_planning_agent(
|
|||
if get_config_repository().get("web_research_enabled", False)
|
||||
else ""
|
||||
)
|
||||
|
||||
|
||||
# Prepare expert guidance section if expert guidance is available
|
||||
expert_guidance_section = ""
|
||||
if expert_guidance:
|
||||
|
|
@ -1050,7 +1125,9 @@ def run_planning_agent(
|
|||
)
|
||||
|
||||
config_values = get_config_repository().get_all()
|
||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
recursion_limit = get_config_repository().get(
|
||||
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||
)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": recursion_limit,
|
||||
|
|
@ -1060,9 +1137,7 @@ def run_planning_agent(
|
|||
try:
|
||||
logger.debug("Planning agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, planning_prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log planning completion
|
||||
log_work_event(f"Completed planning phase for: {base_task}")
|
||||
|
|
@ -1135,7 +1210,7 @@ def run_task_implementation_agent(
|
|||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
key_facts = ""
|
||||
|
||||
|
||||
# Get formatted research notes using repository
|
||||
try:
|
||||
repository = get_research_note_repository()
|
||||
|
|
@ -1144,7 +1219,7 @@ def run_task_implementation_agent(
|
|||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access research note repository: {str(e)}")
|
||||
formatted_research_notes = ""
|
||||
|
||||
|
||||
# Get latest project info
|
||||
try:
|
||||
project_info = get_project_info(".")
|
||||
|
|
@ -1152,24 +1227,26 @@ def run_task_implementation_agent(
|
|||
except Exception as e:
|
||||
logger.warning("Failed to get project info: %s", str(e))
|
||||
formatted_project_info = "Project info unavailable"
|
||||
|
||||
|
||||
# Get environment inventory information
|
||||
env_inv = get_env_inv()
|
||||
|
||||
|
||||
# Get model configuration to check for reasoning_assist_default
|
||||
provider = get_config_repository().get("provider", "")
|
||||
model_name = get_config_repository().get("model", "")
|
||||
logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)
|
||||
|
||||
|
||||
model_config = {}
|
||||
provider_models = models_params.get(provider, {})
|
||||
if provider_models and model_name in provider_models:
|
||||
model_config = provider_models[model_name]
|
||||
|
||||
|
||||
# Check if reasoning assist is explicitly enabled/disabled
|
||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
||||
|
||||
disable_assistance = get_config_repository().get(
|
||||
"disable_reasoning_assistance", False
|
||||
)
|
||||
|
||||
if force_assistance:
|
||||
reasoning_assist_enabled = True
|
||||
elif disable_assistance:
|
||||
|
|
@ -1177,71 +1254,84 @@ def run_task_implementation_agent(
|
|||
else:
|
||||
# Fall back to model default
|
||||
reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)
|
||||
|
||||
|
||||
logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
|
||||
|
||||
|
||||
# Initialize implementation guidance section
|
||||
implementation_guidance_section = ""
|
||||
|
||||
|
||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||
if reasoning_assist_enabled:
|
||||
try:
|
||||
logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
|
||||
|
||||
logger.info(
|
||||
"Reasoning assist enabled for model %s, getting implementation guidance",
|
||||
model_name,
|
||||
)
|
||||
|
||||
# Collect tool descriptions
|
||||
tool_metadata = []
|
||||
from ra_aid.tools.reflection import get_function_info as get_tool_info
|
||||
|
||||
|
||||
for tool in tools:
|
||||
try:
|
||||
tool_info = get_tool_info(tool.func)
|
||||
name = tool.func.__name__
|
||||
description = inspect.getdoc(tool.func)
|
||||
tool_metadata.append(f"Tool: {name}\\nDescription: {description}\\n")
|
||||
tool_metadata.append(
|
||||
f"Tool: {name}\\nDescription: {description}\\n"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting tool info for {tool}: {e}")
|
||||
|
||||
|
||||
# Format tool metadata
|
||||
formatted_tool_metadata = "\\n".join(tool_metadata)
|
||||
|
||||
|
||||
# Initialize expert model
|
||||
expert_model = initialize_expert_llm(provider, model_name)
|
||||
|
||||
|
||||
# Format the reasoning assist prompt for implementation
|
||||
reasoning_assist_prompt = REASONING_ASSIST_PROMPT_IMPLEMENTATION.format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
task=task,
|
||||
key_facts=key_facts,
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
key_snippets=format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
),
|
||||
research_notes=formatted_research_notes,
|
||||
related_files="\\n".join(related_files),
|
||||
env_inv=env_inv,
|
||||
tool_metadata=formatted_tool_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Show the reasoning assist query in a panel
|
||||
console.print(
|
||||
Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
|
||||
Panel(
|
||||
Markdown(
|
||||
"Consulting with the reasoning model on the best implementation approach."
|
||||
),
|
||||
title="📝 Thinking about implementation...",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger.debug("Invoking expert model for implementation reasoning assist")
|
||||
# Make the call to the expert model
|
||||
response = expert_model.invoke(reasoning_assist_prompt)
|
||||
|
||||
|
||||
# Check if the model supports think tags
|
||||
supports_think_tag = model_config.get("supports_think_tag", False)
|
||||
supports_thinking = model_config.get("supports_thinking", False)
|
||||
|
||||
|
||||
# Process response content
|
||||
content = None
|
||||
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
if hasattr(response, "content"):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback if content attribute is missing
|
||||
content = str(response)
|
||||
|
||||
|
||||
# Process the response content using the centralized function
|
||||
content, extracted_thinking = process_thinking_content(
|
||||
content=content,
|
||||
|
|
@ -1249,24 +1339,28 @@ def run_task_implementation_agent(
|
|||
supports_thinking=supports_thinking,
|
||||
panel_title="💭 Implementation Thinking",
|
||||
panel_style="yellow",
|
||||
logger=logger
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
||||
# Display the implementation guidance in a panel
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
|
||||
Panel(
|
||||
Markdown(content),
|
||||
title="Implementation Guidance",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Format the implementation guidance section for the prompt
|
||||
implementation_guidance_section = f"""<implementation guidance>
|
||||
{content}
|
||||
</implementation guidance>"""
|
||||
|
||||
|
||||
logger.info("Received implementation guidance")
|
||||
except Exception as e:
|
||||
logger.error("Error getting implementation guidance: %s", e)
|
||||
implementation_guidance_section = ""
|
||||
|
||||
|
||||
prompt = IMPLEMENTATION_PROMPT.format(
|
||||
current_date=current_date,
|
||||
working_directory=working_directory,
|
||||
|
|
@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
|
|||
plan=plan,
|
||||
related_files=related_files,
|
||||
key_facts=key_facts,
|
||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
||||
key_snippets=format_key_snippets_dict(
|
||||
get_key_snippet_repository().get_snippets_dict()
|
||||
),
|
||||
research_notes=formatted_research_notes,
|
||||
work_log=get_work_log_repository().format_work_log(),
|
||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||
|
|
@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
|
|||
)
|
||||
|
||||
config_values = get_config_repository().get_all()
|
||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
||||
recursion_limit = get_config_repository().get(
|
||||
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||
)
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": recursion_limit,
|
||||
|
|
@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
|
|||
try:
|
||||
logger.debug("Implementation agent completed successfully")
|
||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||
_result = run_agent_with_retry(
|
||||
agent, prompt, none_or_fallback_handler
|
||||
)
|
||||
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||
if _result:
|
||||
# Log task implementation completion
|
||||
log_work_event(f"Completed implementation of task: {task}")
|
||||
|
|
@ -1380,27 +1476,37 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
# 1. Check if this is a ValueError with 429 code or rate limit phrases
|
||||
if isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
|
||||
if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
|
||||
rate_limit_phrases = [
|
||||
"429",
|
||||
"rate limit",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
]
|
||||
if "code" not in error_str and not any(
|
||||
phrase in error_str for phrase in rate_limit_phrases
|
||||
):
|
||||
raise e
|
||||
|
||||
|
||||
# 2. Check for status_code or http_status attribute equal to 429
|
||||
if hasattr(e, 'status_code') and e.status_code == 429:
|
||||
if hasattr(e, "status_code") and e.status_code == 429:
|
||||
pass # This is a rate limit error, continue with retry logic
|
||||
elif hasattr(e, 'http_status') and e.http_status == 429:
|
||||
elif hasattr(e, "http_status") and e.http_status == 429:
|
||||
pass # This is a rate limit error, continue with retry logic
|
||||
# 3. Check for rate limit phrases in error message
|
||||
elif isinstance(e, Exception) and not isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
|
||||
if not any(
|
||||
phrase in error_str
|
||||
for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
|
||||
) and not ("rate" in error_str and "limit" in error_str):
|
||||
# This doesn't look like a rate limit error, but we'll still retry other API errors
|
||||
pass
|
||||
|
||||
|
||||
# Apply common retry logic for all identified errors
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("Max retries reached, failing: %s", str(e))
|
||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||
|
||||
|
||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
|
|
@ -1457,55 +1563,78 @@ def _handle_fallback_response(
|
|||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
||||
"""
|
||||
Streams agent output while handling completion and interruption.
|
||||
|
||||
|
||||
For each chunk, it logs the output, calls check_interrupt(), prints agent output,
|
||||
and then checks if is_completed() or should_exit() are true. If so, it resets completion
|
||||
flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks),
|
||||
the function retrieves the agent's state. If the state indicates further steps (i.e. state.next is non-empty),
|
||||
it resumes execution via agent.invoke(None, config); otherwise, it exits the loop.
|
||||
|
||||
|
||||
This function adheres to the latest LangGraph best practices (as of March 2025) for handling
|
||||
human-in-the-loop interruptions using interrupt_after=["tools"].
|
||||
"""
|
||||
config = get_config_repository().get_all()
|
||||
stream_config = config.copy()
|
||||
|
||||
cb = None
|
||||
if is_anthropic_claude(config):
|
||||
model_name = config.get("model", "")
|
||||
full_model_name = model_name
|
||||
cb = AnthropicCallbackHandler(full_model_name)
|
||||
|
||||
if "callbacks" not in stream_config:
|
||||
stream_config["callbacks"] = []
|
||||
stream_config["callbacks"].append(cb)
|
||||
|
||||
while True:
|
||||
# Process each chunk from the agent stream.
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
for chunk in agent.stream({"messages": msg_list}, stream_config):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
|
||||
if is_completed() or should_exit():
|
||||
reset_completion_flags()
|
||||
return True # Exit immediately when finished or signaled to exit.
|
||||
if cb:
|
||||
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||
return True
|
||||
|
||||
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
||||
|
||||
|
||||
# Prepare state configuration, ensuring 'configurable' is present.
|
||||
state_config = get_config_repository().get_all().copy()
|
||||
if "configurable" not in state_config:
|
||||
logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
|
||||
logger.debug(
|
||||
"Key 'configurable' not found in config; adding it as an empty dict."
|
||||
)
|
||||
state_config["configurable"] = {}
|
||||
logger.debug("Using state_config for agent.get_state(): %s", state_config)
|
||||
|
||||
|
||||
try:
|
||||
state = agent.get_state(state_config)
|
||||
logger.debug("Agent state retrieved: %s", state)
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
|
||||
logger.error(
|
||||
"Error retrieving agent state with state_config %s: %s", state_config, e
|
||||
)
|
||||
raise
|
||||
|
||||
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
||||
# then resume execution by invoking the agent with no new input.
|
||||
|
||||
if state.next:
|
||||
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
||||
agent.invoke(None, config)
|
||||
logger.debug(
|
||||
"State indicates continuation (state.next: %s); resuming execution.",
|
||||
state.next,
|
||||
)
|
||||
agent.invoke(None, stream_config)
|
||||
continue
|
||||
else:
|
||||
logger.debug("No continuation indicated in state; exiting stream loop.")
|
||||
break
|
||||
|
||||
if cb:
|
||||
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||
return True
|
||||
|
||||
|
||||
def run_agent_with_retry(
|
||||
agent: RAgents,
|
||||
prompt: str,
|
||||
|
|
@ -1517,7 +1646,9 @@ def run_agent_with_retry(
|
|||
max_retries = 20
|
||||
base_delay = 1
|
||||
test_attempts = 0
|
||||
_max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
||||
_max_test_retries = get_config_repository().get(
|
||||
"max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
|
||||
)
|
||||
auto_test = get_config_repository().get("auto_test", False)
|
||||
original_prompt = prompt
|
||||
msg_list = [HumanMessage(content=prompt)]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,270 @@
|
|||
"""Custom callback handlers for tracking token usage and costs."""
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
# Define cost per 1K tokens for various models
|
||||
ANTHROPIC_MODEL_COSTS = {
|
||||
# Claude 3.7 Sonnet input
|
||||
"claude-3-7-sonnet-20250219": 0.003,
|
||||
"anthropic/claude-3.7-sonnet": 0.003,
|
||||
"claude-3.7-sonnet": 0.003,
|
||||
# Claude 3.7 Sonnet output
|
||||
"claude-3-7-sonnet-20250219-completion": 0.015,
|
||||
"anthropic/claude-3.7-sonnet-completion": 0.015,
|
||||
"claude-3.7-sonnet-completion": 0.015,
|
||||
# Claude 3 Opus input
|
||||
"claude-3-opus-20240229": 0.015,
|
||||
"anthropic/claude-3-opus": 0.015,
|
||||
"claude-3-opus": 0.015,
|
||||
# Claude 3 Opus output
|
||||
"claude-3-opus-20240229-completion": 0.075,
|
||||
"anthropic/claude-3-opus-completion": 0.075,
|
||||
"claude-3-opus-completion": 0.075,
|
||||
# Claude 3 Sonnet input
|
||||
"claude-3-sonnet-20240229": 0.003,
|
||||
"anthropic/claude-3-sonnet": 0.003,
|
||||
"claude-3-sonnet": 0.003,
|
||||
# Claude 3 Sonnet output
|
||||
"claude-3-sonnet-20240229-completion": 0.015,
|
||||
"anthropic/claude-3-sonnet-completion": 0.015,
|
||||
"claude-3-sonnet-completion": 0.015,
|
||||
# Claude 3 Haiku input
|
||||
"claude-3-haiku-20240307": 0.00025,
|
||||
"anthropic/claude-3-haiku": 0.00025,
|
||||
"claude-3-haiku": 0.00025,
|
||||
# Claude 3 Haiku output
|
||||
"claude-3-haiku-20240307-completion": 0.00125,
|
||||
"anthropic/claude-3-haiku-completion": 0.00125,
|
||||
"claude-3-haiku-completion": 0.00125,
|
||||
# Claude 2 input
|
||||
"claude-2": 0.008,
|
||||
"claude-2.0": 0.008,
|
||||
"claude-2.1": 0.008,
|
||||
# Claude 2 output
|
||||
"claude-2-completion": 0.024,
|
||||
"claude-2.0-completion": 0.024,
|
||||
"claude-2.1-completion": 0.024,
|
||||
# Claude Instant input
|
||||
"claude-instant-1": 0.0016,
|
||||
"claude-instant-1.2": 0.0016,
|
||||
# Claude Instant output
|
||||
"claude-instant-1-completion": 0.0055,
|
||||
"claude-instant-1.2-completion": 0.0055,
|
||||
}
|
||||
|
||||
|
||||
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
|
||||
"""
|
||||
Standardize the model name to a format that can be used for cost calculation.
|
||||
|
||||
Args:
|
||||
model_name: Model name to standardize.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
|
||||
Returns:
|
||||
Standardized model name.
|
||||
"""
|
||||
if not model_name:
|
||||
model_name = "claude-3-sonnet"
|
||||
|
||||
model_name = model_name.lower()
|
||||
|
||||
# Handle OpenRouter prefixes
|
||||
if model_name.startswith("anthropic/"):
|
||||
model_name = model_name[len("anthropic/") :]
|
||||
|
||||
# Add completion suffix if needed
|
||||
if is_completion and not model_name.endswith("-completion"):
|
||||
model_name = model_name + "-completion"
|
||||
|
||||
return model_name
|
||||
|
||||
|
||||
def get_anthropic_token_cost_for_model(
|
||||
model_name: str, num_tokens: int, is_completion: bool = False
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost in USD for a given model and number of tokens.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
num_tokens: Number of tokens.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
|
||||
Returns:
|
||||
Cost in USD.
|
||||
"""
|
||||
model_name = standardize_model_name(model_name, is_completion)
|
||||
|
||||
if model_name not in ANTHROPIC_MODEL_COSTS:
|
||||
# Default to Claude 3 Sonnet pricing if model not found
|
||||
model_name = (
|
||||
"claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
|
||||
)
|
||||
|
||||
cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
|
||||
total_cost = cost_per_1k * (num_tokens / 1000)
|
||||
|
||||
return total_cost
|
||||
|
||||
|
||||
class AnthropicCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that tracks Anthropic token usage and costs."""
|
||||
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
successful_requests: int = 0
|
||||
total_cost: float = 0.0
|
||||
model_name: str = "claude-3-sonnet" # Default model
|
||||
|
||||
def __init__(self, model_name: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
if model_name:
|
||||
self.model_name = model_name
|
||||
|
||||
# Default costs for Claude 3.7 Sonnet
|
||||
self.input_cost_per_token = 0.003 / 1000 # $3/M input tokens
|
||||
self.output_cost_per_token = 0.015 / 1000 # $15/M output tokens
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Tokens Used: {self.total_tokens}\n"
|
||||
f"\tPrompt Tokens: {self.prompt_tokens}\n"
|
||||
f"\tCompletion Tokens: {self.completion_tokens}\n"
|
||||
f"Successful Requests: {self.successful_requests}\n"
|
||||
f"Total Cost (USD): ${self.total_cost:.6f}"
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Record the model name if available."""
|
||||
if "name" in serialized:
|
||||
self.model_name = serialized["name"]
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Count tokens as they're generated."""
|
||||
with self._lock:
|
||||
self.completion_tokens += 1
|
||||
self.total_tokens += 1
|
||||
token_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, 1, is_completion=True
|
||||
)
|
||||
self.total_cost += token_cost
|
||||
|
||||
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||
"""Collect token usage from response."""
|
||||
token_usage = {}
|
||||
|
||||
# Try to extract token usage from response
|
||||
if hasattr(response, "llm_output") and response.llm_output:
|
||||
llm_output = response.llm_output
|
||||
if "token_usage" in llm_output:
|
||||
token_usage = llm_output["token_usage"]
|
||||
elif "usage" in llm_output:
|
||||
usage = llm_output["usage"]
|
||||
|
||||
# Handle Anthropic's specific usage format
|
||||
if "input_tokens" in usage:
|
||||
token_usage["prompt_tokens"] = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
token_usage["completion_tokens"] = usage["output_tokens"]
|
||||
|
||||
# Extract model name if available
|
||||
if "model_name" in llm_output:
|
||||
self.model_name = llm_output["model_name"]
|
||||
|
||||
# Try to get usage from response.usage
|
||||
elif hasattr(response, "usage"):
|
||||
usage = response.usage
|
||||
if hasattr(usage, "prompt_tokens"):
|
||||
token_usage["prompt_tokens"] = usage.prompt_tokens
|
||||
if hasattr(usage, "completion_tokens"):
|
||||
token_usage["completion_tokens"] = usage.completion_tokens
|
||||
if hasattr(usage, "total_tokens"):
|
||||
token_usage["total_tokens"] = usage.total_tokens
|
||||
|
||||
# Extract usage from generations if available
|
||||
elif hasattr(response, "generations") and response.generations:
|
||||
for gen in response.generations:
|
||||
if gen and hasattr(gen[0], "generation_info"):
|
||||
gen_info = gen[0].generation_info or {}
|
||||
if "usage" in gen_info:
|
||||
token_usage = gen_info["usage"]
|
||||
break
|
||||
|
||||
# Update counts with lock to prevent race conditions
|
||||
with self._lock:
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
|
||||
# Only update prompt tokens if we have them
|
||||
if prompt_tokens > 0:
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.total_tokens += prompt_tokens
|
||||
prompt_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, prompt_tokens, is_completion=False
|
||||
)
|
||||
self.total_cost += prompt_cost
|
||||
|
||||
# Only update completion tokens if not already counted by on_llm_new_token
|
||||
if completion_tokens > 0 and completion_tokens > self.completion_tokens:
|
||||
additional_tokens = completion_tokens - self.completion_tokens
|
||||
self.completion_tokens = completion_tokens
|
||||
self.total_tokens += additional_tokens
|
||||
completion_cost = get_anthropic_token_cost_for_model(
|
||||
self.model_name, additional_tokens, is_completion=True
|
||||
)
|
||||
self.total_cost += completion_cost
|
||||
|
||||
self.successful_requests += 1
|
||||
|
||||
def __copy__(self) -> "AnthropicCallbackHandler":
|
||||
"""Return a copy of the callback handler."""
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
|
||||
"""Return a deep copy of the callback handler."""
|
||||
return self
|
||||
|
||||
|
||||
# Create a context variable for our custom callback
|
||||
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
|
||||
"anthropic_callback", default=None
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_anthropic_callback(
|
||||
model_name: Optional[str] = None,
|
||||
) -> AnthropicCallbackHandler:
|
||||
"""Get the Anthropic callback handler in a context manager.
|
||||
which conveniently exposes token and cost information.
|
||||
|
||||
Args:
|
||||
model_name: Optional model name to use for cost calculation.
|
||||
|
||||
Returns:
|
||||
AnthropicCallbackHandler: The Anthropic callback handler.
|
||||
|
||||
Example:
|
||||
>>> with get_anthropic_callback("claude-3-sonnet") as cb:
|
||||
... # Use the callback handler
|
||||
... # cb.total_tokens, cb.total_cost will be available after
|
||||
"""
|
||||
cb = AnthropicCallbackHandler(model_name)
|
||||
anthropic_callback_var.set(cb)
|
||||
yield cb
|
||||
anthropic_callback_var.set(None)
|
||||
Loading…
Reference in New Issue