feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models (#118)
style(agent_utils.py): format imports and code for better readability refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
This commit is contained in:
parent
d194868cff
commit
2899b5f848
|
|
@ -10,6 +10,9 @@ import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
|
||||||
from openai import RateLimitError as OpenAIRateLimitError
|
from openai import RateLimitError as OpenAIRateLimitError
|
||||||
|
|
@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
|
||||||
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
|
||||||
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
|
||||||
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
|
||||||
from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
|
from ra_aid.prompts.reasoning_assist_prompt import (
|
||||||
|
REASONING_ASSIST_PROMPT_PLANNING,
|
||||||
|
REASONING_ASSIST_PROMPT_IMPLEMENTATION,
|
||||||
|
REASONING_ASSIST_PROMPT_RESEARCH,
|
||||||
|
)
|
||||||
from ra_aid.prompts.research_prompts import (
|
from ra_aid.prompts.research_prompts import (
|
||||||
RESEARCH_ONLY_PROMPT,
|
RESEARCH_ONLY_PROMPT,
|
||||||
RESEARCH_PROMPT,
|
RESEARCH_PROMPT,
|
||||||
|
|
@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
|
||||||
)
|
)
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import (
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
get_key_snippet_repository,
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
)
|
||||||
|
from ra_aid.database.repositories.human_input_repository import (
|
||||||
|
get_human_input_repository,
|
||||||
|
)
|
||||||
|
from ra_aid.database.repositories.research_note_repository import (
|
||||||
|
get_research_note_repository,
|
||||||
|
)
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||||
from ra_aid.model_formatters import format_key_facts_dict
|
from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
|
|
@ -332,7 +345,9 @@ def create_agent(
|
||||||
if is_anthropic_claude(config):
|
if is_anthropic_claude(config):
|
||||||
logger.debug("Using create_react_agent to instantiate agent.")
|
logger.debug("Using create_react_agent to instantiate agent.")
|
||||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
return create_react_agent(
|
||||||
|
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("Using CiaynAgent agent instance")
|
logger.debug("Using CiaynAgent agent instance")
|
||||||
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
|
||||||
|
|
@ -343,7 +358,9 @@ def create_agent(
|
||||||
config = get_config_repository().get_all()
|
config = get_config_repository().get_all()
|
||||||
max_input_tokens = get_model_token_limit(config, agent_type)
|
max_input_tokens = get_model_token_limit(config, agent_type)
|
||||||
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
|
||||||
return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
|
return create_react_agent(
|
||||||
|
model, tools, interrupt_after=["tools"], **agent_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def run_research_agent(
|
def run_research_agent(
|
||||||
|
|
@ -406,7 +423,9 @@ def run_research_agent(
|
||||||
recent_inputs = human_input_repository.get_recent(1)
|
recent_inputs = human_input_repository.get_recent(1)
|
||||||
if recent_inputs and len(recent_inputs) > 0:
|
if recent_inputs and len(recent_inputs) > 0:
|
||||||
last_human_input = recent_inputs[0].content
|
last_human_input = recent_inputs[0].content
|
||||||
base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
base_task = (
|
||||||
|
f"<last human input>{last_human_input}</last human input>\n{base_task}"
|
||||||
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access human input repository: {str(e)}")
|
logger.error(f"Failed to access human input repository: {str(e)}")
|
||||||
# Continue without appending last human input
|
# Continue without appending last human input
|
||||||
|
|
@ -416,7 +435,9 @@ def run_research_agent(
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
key_facts = ""
|
key_facts = ""
|
||||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
key_snippets = format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
)
|
||||||
related_files = get_related_files()
|
related_files = get_related_files()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -445,7 +466,9 @@ def run_research_agent(
|
||||||
|
|
||||||
# Check if reasoning assist is explicitly enabled/disabled
|
# Check if reasoning assist is explicitly enabled/disabled
|
||||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
disable_assistance = get_config_repository().get(
|
||||||
|
"disable_reasoning_assistance", False
|
||||||
|
)
|
||||||
if force_assistance:
|
if force_assistance:
|
||||||
reasoning_assist_enabled = True
|
reasoning_assist_enabled = True
|
||||||
elif disable_assistance:
|
elif disable_assistance:
|
||||||
|
|
@ -459,7 +482,9 @@ def run_research_agent(
|
||||||
|
|
||||||
# Get research note information for reasoning assistance
|
# Get research note information for reasoning assistance
|
||||||
try:
|
try:
|
||||||
research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
|
research_notes = format_research_notes_dict(
|
||||||
|
get_research_note_repository().get_notes_dict()
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get research notes: {e}")
|
logger.warning(f"Failed to get research notes: {e}")
|
||||||
research_notes = ""
|
research_notes = ""
|
||||||
|
|
@ -467,7 +492,10 @@ def run_research_agent(
|
||||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||||
if reasoning_assist_enabled:
|
if reasoning_assist_enabled:
|
||||||
try:
|
try:
|
||||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
logger.info(
|
||||||
|
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Collect tool descriptions
|
# Collect tool descriptions
|
||||||
tool_metadata = []
|
tool_metadata = []
|
||||||
|
|
@ -503,7 +531,13 @@ def run_research_agent(
|
||||||
|
|
||||||
# Show the reasoning assist query in a panel
|
# Show the reasoning assist query in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
|
Panel(
|
||||||
|
Markdown(
|
||||||
|
"Consulting with the reasoning model on the best research approach."
|
||||||
|
),
|
||||||
|
title="📝 Thinking about research strategy...",
|
||||||
|
border_style="yellow",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Invoking expert model for reasoning assist")
|
logger.debug("Invoking expert model for reasoning assist")
|
||||||
|
|
@ -517,7 +551,7 @@ def run_research_agent(
|
||||||
# Get response content, handling if it's a list (for Claude thinking mode)
|
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, "content"):
|
||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
# Fallback if content attribute is missing
|
# Fallback if content attribute is missing
|
||||||
|
|
@ -533,19 +567,27 @@ def run_research_agent(
|
||||||
for item in content:
|
for item in content:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
# Extract thinking content
|
# Extract thinking content
|
||||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
if item.get("type") == "thinking" and "thinking" in item:
|
||||||
thinking_content = item['thinking']
|
thinking_content = item["thinking"]
|
||||||
logger.debug("Found structured thinking content")
|
logger.debug("Found structured thinking content")
|
||||||
# Extract response text
|
# Extract response text
|
||||||
elif item.get('type') == 'text' and 'text' in item:
|
elif item.get("type") == "text" and "text" in item:
|
||||||
response_text = item['text']
|
response_text = item["text"]
|
||||||
logger.debug("Found structured response text")
|
logger.debug("Found structured response text")
|
||||||
|
|
||||||
# Display thinking content in a separate panel if available
|
# Display thinking content in a separate panel if available
|
||||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
if thinking_content and get_config_repository().get(
|
||||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
"show_thoughts", False
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||||
|
)
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
Panel(
|
||||||
|
Markdown(thinking_content),
|
||||||
|
title="💭 Expert Thinking",
|
||||||
|
border_style="yellow",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use response_text if available, otherwise fall back to joining
|
# Use response_text if available, otherwise fall back to joining
|
||||||
|
|
@ -553,9 +595,11 @@ def run_research_agent(
|
||||||
content = response_text
|
content = response_text
|
||||||
else:
|
else:
|
||||||
# Fallback: join list items if structured extraction failed
|
# Fallback: join list items if structured extraction failed
|
||||||
logger.debug("No structured response text found, joining list items")
|
logger.debug(
|
||||||
|
"No structured response text found, joining list items"
|
||||||
|
)
|
||||||
content = "\n".join(str(item) for item in content)
|
content = "\n".join(str(item) for item in content)
|
||||||
elif (supports_think_tag or supports_thinking):
|
elif supports_think_tag or supports_thinking:
|
||||||
# Process thinking content using the centralized function
|
# Process thinking content using the centralized function
|
||||||
content, _ = process_thinking_content(
|
content, _ = process_thinking_content(
|
||||||
content=content,
|
content=content,
|
||||||
|
|
@ -563,16 +607,22 @@ def run_research_agent(
|
||||||
supports_thinking=supports_thinking,
|
supports_thinking=supports_thinking,
|
||||||
panel_title="💭 Expert Thinking",
|
panel_title="💭 Expert Thinking",
|
||||||
panel_style="yellow",
|
panel_style="yellow",
|
||||||
logger=logger
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display the expert guidance in a panel
|
# Display the expert guidance in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
|
Panel(
|
||||||
|
Markdown(content),
|
||||||
|
title="Research Strategy Guidance",
|
||||||
|
border_style="blue",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the content as expert guidance
|
# Use the content as expert guidance
|
||||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
expert_guidance = (
|
||||||
|
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Received expert guidance for research")
|
logger.info("Received expert guidance for research")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -643,9 +693,7 @@ def run_research_agent(
|
||||||
if agent is not None:
|
if agent is not None:
|
||||||
logger.debug("Research agent created successfully")
|
logger.debug("Research agent created successfully")
|
||||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||||
_result = run_agent_with_retry(
|
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||||
agent, prompt, none_or_fallback_handler
|
|
||||||
)
|
|
||||||
if _result:
|
if _result:
|
||||||
# Log research completion
|
# Log research completion
|
||||||
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
log_work_event(f"Completed research phase for: {base_task_or_query}")
|
||||||
|
|
@ -731,7 +779,9 @@ def run_web_research_agent(
|
||||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
key_facts = ""
|
key_facts = ""
|
||||||
try:
|
try:
|
||||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
key_snippets = format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
key_snippets = ""
|
key_snippets = ""
|
||||||
|
|
@ -771,9 +821,7 @@ def run_web_research_agent(
|
||||||
|
|
||||||
logger.debug("Web research agent completed successfully")
|
logger.debug("Web research agent completed successfully")
|
||||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||||
_result = run_agent_with_retry(
|
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||||
agent, prompt, none_or_fallback_handler
|
|
||||||
)
|
|
||||||
if _result:
|
if _result:
|
||||||
# Log web research completion
|
# Log web research completion
|
||||||
log_work_event(f"Completed web research phase for: {query}")
|
log_work_event(f"Completed web research phase for: {query}")
|
||||||
|
|
@ -844,7 +892,9 @@ def run_planning_agent(
|
||||||
|
|
||||||
# Check if reasoning assist is explicitly enabled/disabled
|
# Check if reasoning assist is explicitly enabled/disabled
|
||||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
disable_assistance = get_config_repository().get(
|
||||||
|
"disable_reasoning_assistance", False
|
||||||
|
)
|
||||||
|
|
||||||
if force_assistance:
|
if force_assistance:
|
||||||
reasoning_assist_enabled = True
|
reasoning_assist_enabled = True
|
||||||
|
|
@ -869,7 +919,9 @@ def run_planning_agent(
|
||||||
|
|
||||||
# Make sure key_snippets is defined before using it
|
# Make sure key_snippets is defined before using it
|
||||||
try:
|
try:
|
||||||
key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
|
key_snippets = format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
logger.error(f"Failed to access key snippet repository: {str(e)}")
|
||||||
key_snippets = ""
|
key_snippets = ""
|
||||||
|
|
@ -898,7 +950,10 @@ def run_planning_agent(
|
||||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||||
if reasoning_assist_enabled:
|
if reasoning_assist_enabled:
|
||||||
try:
|
try:
|
||||||
logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
|
logger.info(
|
||||||
|
"Reasoning assist enabled for model %s, getting expert guidance",
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Collect tool descriptions
|
# Collect tool descriptions
|
||||||
tool_metadata = []
|
tool_metadata = []
|
||||||
|
|
@ -934,7 +989,13 @@ def run_planning_agent(
|
||||||
|
|
||||||
# Show the reasoning assist query in a panel
|
# Show the reasoning assist query in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
|
Panel(
|
||||||
|
Markdown(
|
||||||
|
"Consulting with the reasoning model on the best way to do this."
|
||||||
|
),
|
||||||
|
title="📝 Thinking about the plan...",
|
||||||
|
border_style="yellow",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Invoking expert model for reasoning assist")
|
logger.debug("Invoking expert model for reasoning assist")
|
||||||
|
|
@ -948,7 +1009,7 @@ def run_planning_agent(
|
||||||
# Get response content, handling if it's a list (for Claude thinking mode)
|
# Get response content, handling if it's a list (for Claude thinking mode)
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, "content"):
|
||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
# Fallback if content attribute is missing
|
# Fallback if content attribute is missing
|
||||||
|
|
@ -964,19 +1025,27 @@ def run_planning_agent(
|
||||||
for item in content:
|
for item in content:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
# Extract thinking content
|
# Extract thinking content
|
||||||
if item.get('type') == 'thinking' and 'thinking' in item:
|
if item.get("type") == "thinking" and "thinking" in item:
|
||||||
thinking_content = item['thinking']
|
thinking_content = item["thinking"]
|
||||||
logger.debug("Found structured thinking content")
|
logger.debug("Found structured thinking content")
|
||||||
# Extract response text
|
# Extract response text
|
||||||
elif item.get('type') == 'text' and 'text' in item:
|
elif item.get("type") == "text" and "text" in item:
|
||||||
response_text = item['text']
|
response_text = item["text"]
|
||||||
logger.debug("Found structured response text")
|
logger.debug("Found structured response text")
|
||||||
|
|
||||||
# Display thinking content in a separate panel if available
|
# Display thinking content in a separate panel if available
|
||||||
if thinking_content and get_config_repository().get("show_thoughts", False):
|
if thinking_content and get_config_repository().get(
|
||||||
logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
|
"show_thoughts", False
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"Displaying structured thinking content ({len(thinking_content)} chars)"
|
||||||
|
)
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
|
Panel(
|
||||||
|
Markdown(thinking_content),
|
||||||
|
title="💭 Expert Thinking",
|
||||||
|
border_style="yellow",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use response_text if available, otherwise fall back to joining
|
# Use response_text if available, otherwise fall back to joining
|
||||||
|
|
@ -984,9 +1053,11 @@ def run_planning_agent(
|
||||||
content = response_text
|
content = response_text
|
||||||
else:
|
else:
|
||||||
# Fallback: join list items if structured extraction failed
|
# Fallback: join list items if structured extraction failed
|
||||||
logger.debug("No structured response text found, joining list items")
|
logger.debug(
|
||||||
|
"No structured response text found, joining list items"
|
||||||
|
)
|
||||||
content = "\n".join(str(item) for item in content)
|
content = "\n".join(str(item) for item in content)
|
||||||
elif (supports_think_tag or supports_thinking):
|
elif supports_think_tag or supports_thinking:
|
||||||
# Process thinking content using the centralized function
|
# Process thinking content using the centralized function
|
||||||
content, _ = process_thinking_content(
|
content, _ = process_thinking_content(
|
||||||
content=content,
|
content=content,
|
||||||
|
|
@ -994,16 +1065,20 @@ def run_planning_agent(
|
||||||
supports_thinking=supports_thinking,
|
supports_thinking=supports_thinking,
|
||||||
panel_title="💭 Expert Thinking",
|
panel_title="💭 Expert Thinking",
|
||||||
panel_style="yellow",
|
panel_style="yellow",
|
||||||
logger=logger
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display the expert guidance in a panel
|
# Display the expert guidance in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
|
Panel(
|
||||||
|
Markdown(content), title="Reasoning Guidance", border_style="blue"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the content as expert guidance
|
# Use the content as expert guidance
|
||||||
expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
expert_guidance = (
|
||||||
|
content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Received expert guidance for planning")
|
logger.info("Received expert guidance for planning")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1050,7 +1125,9 @@ def run_planning_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
config_values = get_config_repository().get_all()
|
config_values = get_config_repository().get_all()
|
||||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = get_config_repository().get(
|
||||||
|
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||||
|
)
|
||||||
run_config = {
|
run_config = {
|
||||||
"configurable": {"thread_id": thread_id},
|
"configurable": {"thread_id": thread_id},
|
||||||
"recursion_limit": recursion_limit,
|
"recursion_limit": recursion_limit,
|
||||||
|
|
@ -1060,9 +1137,7 @@ def run_planning_agent(
|
||||||
try:
|
try:
|
||||||
logger.debug("Planning agent completed successfully")
|
logger.debug("Planning agent completed successfully")
|
||||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||||
_result = run_agent_with_retry(
|
_result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
|
||||||
agent, planning_prompt, none_or_fallback_handler
|
|
||||||
)
|
|
||||||
if _result:
|
if _result:
|
||||||
# Log planning completion
|
# Log planning completion
|
||||||
log_work_event(f"Completed planning phase for: {base_task}")
|
log_work_event(f"Completed planning phase for: {base_task}")
|
||||||
|
|
@ -1168,7 +1243,9 @@ def run_task_implementation_agent(
|
||||||
|
|
||||||
# Check if reasoning assist is explicitly enabled/disabled
|
# Check if reasoning assist is explicitly enabled/disabled
|
||||||
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
force_assistance = get_config_repository().get("force_reasoning_assistance", False)
|
||||||
disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
|
disable_assistance = get_config_repository().get(
|
||||||
|
"disable_reasoning_assistance", False
|
||||||
|
)
|
||||||
|
|
||||||
if force_assistance:
|
if force_assistance:
|
||||||
reasoning_assist_enabled = True
|
reasoning_assist_enabled = True
|
||||||
|
|
@ -1186,7 +1263,10 @@ def run_task_implementation_agent(
|
||||||
# If reasoning assist is enabled, make a one-off call to the expert model
|
# If reasoning assist is enabled, make a one-off call to the expert model
|
||||||
if reasoning_assist_enabled:
|
if reasoning_assist_enabled:
|
||||||
try:
|
try:
|
||||||
logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
|
logger.info(
|
||||||
|
"Reasoning assist enabled for model %s, getting implementation guidance",
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
|
||||||
# Collect tool descriptions
|
# Collect tool descriptions
|
||||||
tool_metadata = []
|
tool_metadata = []
|
||||||
|
|
@ -1197,7 +1277,9 @@ def run_task_implementation_agent(
|
||||||
tool_info = get_tool_info(tool.func)
|
tool_info = get_tool_info(tool.func)
|
||||||
name = tool.func.__name__
|
name = tool.func.__name__
|
||||||
description = inspect.getdoc(tool.func)
|
description = inspect.getdoc(tool.func)
|
||||||
tool_metadata.append(f"Tool: {name}\\nDescription: {description}\\n")
|
tool_metadata.append(
|
||||||
|
f"Tool: {name}\\nDescription: {description}\\n"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error getting tool info for {tool}: {e}")
|
logger.warning(f"Error getting tool info for {tool}: {e}")
|
||||||
|
|
||||||
|
|
@ -1213,7 +1295,9 @@ def run_task_implementation_agent(
|
||||||
working_directory=working_directory,
|
working_directory=working_directory,
|
||||||
task=task,
|
task=task,
|
||||||
key_facts=key_facts,
|
key_facts=key_facts,
|
||||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
key_snippets=format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
),
|
||||||
research_notes=formatted_research_notes,
|
research_notes=formatted_research_notes,
|
||||||
related_files="\\n".join(related_files),
|
related_files="\\n".join(related_files),
|
||||||
env_inv=env_inv,
|
env_inv=env_inv,
|
||||||
|
|
@ -1222,7 +1306,13 @@ def run_task_implementation_agent(
|
||||||
|
|
||||||
# Show the reasoning assist query in a panel
|
# Show the reasoning assist query in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
|
Panel(
|
||||||
|
Markdown(
|
||||||
|
"Consulting with the reasoning model on the best implementation approach."
|
||||||
|
),
|
||||||
|
title="📝 Thinking about implementation...",
|
||||||
|
border_style="yellow",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Invoking expert model for implementation reasoning assist")
|
logger.debug("Invoking expert model for implementation reasoning assist")
|
||||||
|
|
@ -1236,7 +1326,7 @@ def run_task_implementation_agent(
|
||||||
# Process response content
|
# Process response content
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, "content"):
|
||||||
content = response.content
|
content = response.content
|
||||||
else:
|
else:
|
||||||
# Fallback if content attribute is missing
|
# Fallback if content attribute is missing
|
||||||
|
|
@ -1249,12 +1339,16 @@ def run_task_implementation_agent(
|
||||||
supports_thinking=supports_thinking,
|
supports_thinking=supports_thinking,
|
||||||
panel_title="💭 Implementation Thinking",
|
panel_title="💭 Implementation Thinking",
|
||||||
panel_style="yellow",
|
panel_style="yellow",
|
||||||
logger=logger
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display the implementation guidance in a panel
|
# Display the implementation guidance in a panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
|
Panel(
|
||||||
|
Markdown(content),
|
||||||
|
title="Implementation Guidance",
|
||||||
|
border_style="blue",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format the implementation guidance section for the prompt
|
# Format the implementation guidance section for the prompt
|
||||||
|
|
@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
|
||||||
plan=plan,
|
plan=plan,
|
||||||
related_files=related_files,
|
related_files=related_files,
|
||||||
key_facts=key_facts,
|
key_facts=key_facts,
|
||||||
key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
|
key_snippets=format_key_snippets_dict(
|
||||||
|
get_key_snippet_repository().get_snippets_dict()
|
||||||
|
),
|
||||||
research_notes=formatted_research_notes,
|
research_notes=formatted_research_notes,
|
||||||
work_log=get_work_log_repository().format_work_log(),
|
work_log=get_work_log_repository().format_work_log(),
|
||||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||||
|
|
@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
|
||||||
)
|
)
|
||||||
|
|
||||||
config_values = get_config_repository().get_all()
|
config_values = get_config_repository().get_all()
|
||||||
recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
|
recursion_limit = get_config_repository().get(
|
||||||
|
"recursion_limit", DEFAULT_RECURSION_LIMIT
|
||||||
|
)
|
||||||
run_config = {
|
run_config = {
|
||||||
"configurable": {"thread_id": thread_id},
|
"configurable": {"thread_id": thread_id},
|
||||||
"recursion_limit": recursion_limit,
|
"recursion_limit": recursion_limit,
|
||||||
|
|
@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
|
||||||
try:
|
try:
|
||||||
logger.debug("Implementation agent completed successfully")
|
logger.debug("Implementation agent completed successfully")
|
||||||
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
none_or_fallback_handler = init_fallback_handler(agent, tools)
|
||||||
_result = run_agent_with_retry(
|
_result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
|
||||||
agent, prompt, none_or_fallback_handler
|
|
||||||
)
|
|
||||||
if _result:
|
if _result:
|
||||||
# Log task implementation completion
|
# Log task implementation completion
|
||||||
log_work_event(f"Completed implementation of task: {task}")
|
log_work_event(f"Completed implementation of task: {task}")
|
||||||
|
|
@ -1380,19 +1476,29 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
# 1. Check if this is a ValueError with 429 code or rate limit phrases
|
# 1. Check if this is a ValueError with 429 code or rate limit phrases
|
||||||
if isinstance(e, ValueError):
|
if isinstance(e, ValueError):
|
||||||
error_str = str(e).lower()
|
error_str = str(e).lower()
|
||||||
rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
|
rate_limit_phrases = [
|
||||||
if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
|
"429",
|
||||||
|
"rate limit",
|
||||||
|
"too many requests",
|
||||||
|
"quota exceeded",
|
||||||
|
]
|
||||||
|
if "code" not in error_str and not any(
|
||||||
|
phrase in error_str for phrase in rate_limit_phrases
|
||||||
|
):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# 2. Check for status_code or http_status attribute equal to 429
|
# 2. Check for status_code or http_status attribute equal to 429
|
||||||
if hasattr(e, 'status_code') and e.status_code == 429:
|
if hasattr(e, "status_code") and e.status_code == 429:
|
||||||
pass # This is a rate limit error, continue with retry logic
|
pass # This is a rate limit error, continue with retry logic
|
||||||
elif hasattr(e, 'http_status') and e.http_status == 429:
|
elif hasattr(e, "http_status") and e.http_status == 429:
|
||||||
pass # This is a rate limit error, continue with retry logic
|
pass # This is a rate limit error, continue with retry logic
|
||||||
# 3. Check for rate limit phrases in error message
|
# 3. Check for rate limit phrases in error message
|
||||||
elif isinstance(e, Exception) and not isinstance(e, ValueError):
|
elif isinstance(e, Exception) and not isinstance(e, ValueError):
|
||||||
error_str = str(e).lower()
|
error_str = str(e).lower()
|
||||||
if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
|
if not any(
|
||||||
|
phrase in error_str
|
||||||
|
for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
|
||||||
|
) and not ("rate" in error_str and "limit" in error_str):
|
||||||
# This doesn't look like a rate limit error, but we'll still retry other API errors
|
# This doesn't look like a rate limit error, but we'll still retry other API errors
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -1468,22 +1574,39 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
||||||
human-in-the-loop interruptions using interrupt_after=["tools"].
|
human-in-the-loop interruptions using interrupt_after=["tools"].
|
||||||
"""
|
"""
|
||||||
config = get_config_repository().get_all()
|
config = get_config_repository().get_all()
|
||||||
|
stream_config = config.copy()
|
||||||
|
|
||||||
|
cb = None
|
||||||
|
if is_anthropic_claude(config):
|
||||||
|
model_name = config.get("model", "")
|
||||||
|
full_model_name = model_name
|
||||||
|
cb = AnthropicCallbackHandler(full_model_name)
|
||||||
|
|
||||||
|
if "callbacks" not in stream_config:
|
||||||
|
stream_config["callbacks"] = []
|
||||||
|
stream_config["callbacks"].append(cb)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Process each chunk from the agent stream.
|
for chunk in agent.stream({"messages": msg_list}, stream_config):
|
||||||
for chunk in agent.stream({"messages": msg_list}, config):
|
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
agent_type = get_agent_type(agent)
|
agent_type = get_agent_type(agent)
|
||||||
print_agent_output(chunk, agent_type)
|
print_agent_output(chunk, agent_type)
|
||||||
|
|
||||||
if is_completed() or should_exit():
|
if is_completed() or should_exit():
|
||||||
reset_completion_flags()
|
reset_completion_flags()
|
||||||
return True # Exit immediately when finished or signaled to exit.
|
if cb:
|
||||||
|
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||||
|
return True
|
||||||
|
|
||||||
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
logger.debug("Stream iteration ended; checking agent state for continuation.")
|
||||||
|
|
||||||
# Prepare state configuration, ensuring 'configurable' is present.
|
# Prepare state configuration, ensuring 'configurable' is present.
|
||||||
state_config = get_config_repository().get_all().copy()
|
state_config = get_config_repository().get_all().copy()
|
||||||
if "configurable" not in state_config:
|
if "configurable" not in state_config:
|
||||||
logger.debug("Key 'configurable' not found in config; adding it as an empty dict.")
|
logger.debug(
|
||||||
|
"Key 'configurable' not found in config; adding it as an empty dict."
|
||||||
|
)
|
||||||
state_config["configurable"] = {}
|
state_config["configurable"] = {}
|
||||||
logger.debug("Using state_config for agent.get_state(): %s", state_config)
|
logger.debug("Using state_config for agent.get_state(): %s", state_config)
|
||||||
|
|
||||||
|
|
@ -1491,21 +1614,27 @@ def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
|
||||||
state = agent.get_state(state_config)
|
state = agent.get_state(state_config)
|
||||||
logger.debug("Agent state retrieved: %s", state)
|
logger.debug("Agent state retrieved: %s", state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error retrieving agent state with state_config %s: %s", state_config, e)
|
logger.error(
|
||||||
|
"Error retrieving agent state with state_config %s: %s", state_config, e
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# If the state indicates that further steps remain (i.e. state.next is non-empty),
|
|
||||||
# then resume execution by invoking the agent with no new input.
|
|
||||||
if state.next:
|
if state.next:
|
||||||
logger.debug("State indicates continuation (state.next: %s); resuming execution.", state.next)
|
logger.debug(
|
||||||
agent.invoke(None, config)
|
"State indicates continuation (state.next: %s); resuming execution.",
|
||||||
|
state.next,
|
||||||
|
)
|
||||||
|
agent.invoke(None, stream_config)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
logger.debug("No continuation indicated in state; exiting stream loop.")
|
logger.debug("No continuation indicated in state; exiting stream loop.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if cb:
|
||||||
|
logger.debug(f"AnthropicCallbackHandler:\n{cb}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def run_agent_with_retry(
|
def run_agent_with_retry(
|
||||||
agent: RAgents,
|
agent: RAgents,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -1517,7 +1646,9 @@ def run_agent_with_retry(
|
||||||
max_retries = 20
|
max_retries = 20
|
||||||
base_delay = 1
|
base_delay = 1
|
||||||
test_attempts = 0
|
test_attempts = 0
|
||||||
_max_test_retries = get_config_repository().get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
|
_max_test_retries = get_config_repository().get(
|
||||||
|
"max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
|
||||||
|
)
|
||||||
auto_test = get_config_repository().get("auto_test", False)
|
auto_test = get_config_repository().get("auto_test", False)
|
||||||
original_prompt = prompt
|
original_prompt = prompt
|
||||||
msg_list = [HumanMessage(content=prompt)]
|
msg_list = [HumanMessage(content=prompt)]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,270 @@
|
||||||
|
"""Custom callback handlers for tracking token usage and costs."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
|
||||||
|
# Define cost per 1K tokens for various models
|
||||||
|
ANTHROPIC_MODEL_COSTS = {
|
||||||
|
# Claude 3.7 Sonnet input
|
||||||
|
"claude-3-7-sonnet-20250219": 0.003,
|
||||||
|
"anthropic/claude-3.7-sonnet": 0.003,
|
||||||
|
"claude-3.7-sonnet": 0.003,
|
||||||
|
# Claude 3.7 Sonnet output
|
||||||
|
"claude-3-7-sonnet-20250219-completion": 0.015,
|
||||||
|
"anthropic/claude-3.7-sonnet-completion": 0.015,
|
||||||
|
"claude-3.7-sonnet-completion": 0.015,
|
||||||
|
# Claude 3 Opus input
|
||||||
|
"claude-3-opus-20240229": 0.015,
|
||||||
|
"anthropic/claude-3-opus": 0.015,
|
||||||
|
"claude-3-opus": 0.015,
|
||||||
|
# Claude 3 Opus output
|
||||||
|
"claude-3-opus-20240229-completion": 0.075,
|
||||||
|
"anthropic/claude-3-opus-completion": 0.075,
|
||||||
|
"claude-3-opus-completion": 0.075,
|
||||||
|
# Claude 3 Sonnet input
|
||||||
|
"claude-3-sonnet-20240229": 0.003,
|
||||||
|
"anthropic/claude-3-sonnet": 0.003,
|
||||||
|
"claude-3-sonnet": 0.003,
|
||||||
|
# Claude 3 Sonnet output
|
||||||
|
"claude-3-sonnet-20240229-completion": 0.015,
|
||||||
|
"anthropic/claude-3-sonnet-completion": 0.015,
|
||||||
|
"claude-3-sonnet-completion": 0.015,
|
||||||
|
# Claude 3 Haiku input
|
||||||
|
"claude-3-haiku-20240307": 0.00025,
|
||||||
|
"anthropic/claude-3-haiku": 0.00025,
|
||||||
|
"claude-3-haiku": 0.00025,
|
||||||
|
# Claude 3 Haiku output
|
||||||
|
"claude-3-haiku-20240307-completion": 0.00125,
|
||||||
|
"anthropic/claude-3-haiku-completion": 0.00125,
|
||||||
|
"claude-3-haiku-completion": 0.00125,
|
||||||
|
# Claude 2 input
|
||||||
|
"claude-2": 0.008,
|
||||||
|
"claude-2.0": 0.008,
|
||||||
|
"claude-2.1": 0.008,
|
||||||
|
# Claude 2 output
|
||||||
|
"claude-2-completion": 0.024,
|
||||||
|
"claude-2.0-completion": 0.024,
|
||||||
|
"claude-2.1-completion": 0.024,
|
||||||
|
# Claude Instant input
|
||||||
|
"claude-instant-1": 0.0016,
|
||||||
|
"claude-instant-1.2": 0.0016,
|
||||||
|
# Claude Instant output
|
||||||
|
"claude-instant-1-completion": 0.0055,
|
||||||
|
"claude-instant-1.2-completion": 0.0055,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
Standardize the model name to a format that can be used for cost calculation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model name to standardize.
|
||||||
|
is_completion: Whether the model is used for completion or not.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Standardized model name.
|
||||||
|
"""
|
||||||
|
if not model_name:
|
||||||
|
model_name = "claude-3-sonnet"
|
||||||
|
|
||||||
|
model_name = model_name.lower()
|
||||||
|
|
||||||
|
# Handle OpenRouter prefixes
|
||||||
|
if model_name.startswith("anthropic/"):
|
||||||
|
model_name = model_name[len("anthropic/") :]
|
||||||
|
|
||||||
|
# Add completion suffix if needed
|
||||||
|
if is_completion and not model_name.endswith("-completion"):
|
||||||
|
model_name = model_name + "-completion"
|
||||||
|
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_anthropic_token_cost_for_model(
|
||||||
|
model_name: str, num_tokens: int, is_completion: bool = False
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Get the cost in USD for a given model and number of tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model
|
||||||
|
num_tokens: Number of tokens.
|
||||||
|
is_completion: Whether the model is used for completion or not.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cost in USD.
|
||||||
|
"""
|
||||||
|
model_name = standardize_model_name(model_name, is_completion)
|
||||||
|
|
||||||
|
if model_name not in ANTHROPIC_MODEL_COSTS:
|
||||||
|
# Default to Claude 3 Sonnet pricing if model not found
|
||||||
|
model_name = (
|
||||||
|
"claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
|
||||||
|
)
|
||||||
|
|
||||||
|
cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
|
||||||
|
total_cost = cost_per_1k * (num_tokens / 1000)
|
||||||
|
|
||||||
|
return total_cost
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicCallbackHandler(BaseCallbackHandler):
|
||||||
|
"""Callback Handler that tracks Anthropic token usage and costs."""
|
||||||
|
|
||||||
|
total_tokens: int = 0
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
successful_requests: int = 0
|
||||||
|
total_cost: float = 0.0
|
||||||
|
model_name: str = "claude-3-sonnet" # Default model
|
||||||
|
|
||||||
|
def __init__(self, model_name: Optional[str] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
if model_name:
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
# Default costs for Claude 3.7 Sonnet
|
||||||
|
self.input_cost_per_token = 0.003 / 1000 # $3/M input tokens
|
||||||
|
self.output_cost_per_token = 0.015 / 1000 # $15/M output tokens
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"Tokens Used: {self.total_tokens}\n"
|
||||||
|
f"\tPrompt Tokens: {self.prompt_tokens}\n"
|
||||||
|
f"\tCompletion Tokens: {self.completion_tokens}\n"
|
||||||
|
f"Successful Requests: {self.successful_requests}\n"
|
||||||
|
f"Total Cost (USD): ${self.total_cost:.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def always_verbose(self) -> bool:
|
||||||
|
"""Whether to call verbose callbacks even if verbose is False."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Record the model name if available."""
|
||||||
|
if "name" in serialized:
|
||||||
|
self.model_name = serialized["name"]
|
||||||
|
|
||||||
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
|
"""Count tokens as they're generated."""
|
||||||
|
with self._lock:
|
||||||
|
self.completion_tokens += 1
|
||||||
|
self.total_tokens += 1
|
||||||
|
token_cost = get_anthropic_token_cost_for_model(
|
||||||
|
self.model_name, 1, is_completion=True
|
||||||
|
)
|
||||||
|
self.total_cost += token_cost
|
||||||
|
|
||||||
|
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||||
|
"""Collect token usage from response."""
|
||||||
|
token_usage = {}
|
||||||
|
|
||||||
|
# Try to extract token usage from response
|
||||||
|
if hasattr(response, "llm_output") and response.llm_output:
|
||||||
|
llm_output = response.llm_output
|
||||||
|
if "token_usage" in llm_output:
|
||||||
|
token_usage = llm_output["token_usage"]
|
||||||
|
elif "usage" in llm_output:
|
||||||
|
usage = llm_output["usage"]
|
||||||
|
|
||||||
|
# Handle Anthropic's specific usage format
|
||||||
|
if "input_tokens" in usage:
|
||||||
|
token_usage["prompt_tokens"] = usage["input_tokens"]
|
||||||
|
if "output_tokens" in usage:
|
||||||
|
token_usage["completion_tokens"] = usage["output_tokens"]
|
||||||
|
|
||||||
|
# Extract model name if available
|
||||||
|
if "model_name" in llm_output:
|
||||||
|
self.model_name = llm_output["model_name"]
|
||||||
|
|
||||||
|
# Try to get usage from response.usage
|
||||||
|
elif hasattr(response, "usage"):
|
||||||
|
usage = response.usage
|
||||||
|
if hasattr(usage, "prompt_tokens"):
|
||||||
|
token_usage["prompt_tokens"] = usage.prompt_tokens
|
||||||
|
if hasattr(usage, "completion_tokens"):
|
||||||
|
token_usage["completion_tokens"] = usage.completion_tokens
|
||||||
|
if hasattr(usage, "total_tokens"):
|
||||||
|
token_usage["total_tokens"] = usage.total_tokens
|
||||||
|
|
||||||
|
# Extract usage from generations if available
|
||||||
|
elif hasattr(response, "generations") and response.generations:
|
||||||
|
for gen in response.generations:
|
||||||
|
if gen and hasattr(gen[0], "generation_info"):
|
||||||
|
gen_info = gen[0].generation_info or {}
|
||||||
|
if "usage" in gen_info:
|
||||||
|
token_usage = gen_info["usage"]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Update counts with lock to prevent race conditions
|
||||||
|
with self._lock:
|
||||||
|
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||||
|
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||||
|
|
||||||
|
# Only update prompt tokens if we have them
|
||||||
|
if prompt_tokens > 0:
|
||||||
|
self.prompt_tokens += prompt_tokens
|
||||||
|
self.total_tokens += prompt_tokens
|
||||||
|
prompt_cost = get_anthropic_token_cost_for_model(
|
||||||
|
self.model_name, prompt_tokens, is_completion=False
|
||||||
|
)
|
||||||
|
self.total_cost += prompt_cost
|
||||||
|
|
||||||
|
# Only update completion tokens if not already counted by on_llm_new_token
|
||||||
|
if completion_tokens > 0 and completion_tokens > self.completion_tokens:
|
||||||
|
additional_tokens = completion_tokens - self.completion_tokens
|
||||||
|
self.completion_tokens = completion_tokens
|
||||||
|
self.total_tokens += additional_tokens
|
||||||
|
completion_cost = get_anthropic_token_cost_for_model(
|
||||||
|
self.model_name, additional_tokens, is_completion=True
|
||||||
|
)
|
||||||
|
self.total_cost += completion_cost
|
||||||
|
|
||||||
|
self.successful_requests += 1
|
||||||
|
|
||||||
|
def __copy__(self) -> "AnthropicCallbackHandler":
|
||||||
|
"""Return a copy of the callback handler."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
|
||||||
|
"""Return a deep copy of the callback handler."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# Create a context variable for our custom callback
|
||||||
|
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
|
||||||
|
"anthropic_callback", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_anthropic_callback(
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> AnthropicCallbackHandler:
|
||||||
|
"""Get the Anthropic callback handler in a context manager.
|
||||||
|
which conveniently exposes token and cost information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Optional model name to use for cost calculation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AnthropicCallbackHandler: The Anthropic callback handler.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> with get_anthropic_callback("claude-3-sonnet") as cb:
|
||||||
|
... # Use the callback handler
|
||||||
|
... # cb.total_tokens, cb.total_cost will be available after
|
||||||
|
"""
|
||||||
|
cb = AnthropicCallbackHandler(model_name)
|
||||||
|
anthropic_callback_var.set(cb)
|
||||||
|
yield cb
|
||||||
|
anthropic_callback_var.set(None)
|
||||||
Loading…
Reference in New Issue