feat(agent_utils.py): add AnthropicCallbackHandler to track token usage and costs for Anthropic models (#118)
style(agent_utils.py): format imports and code for better readability
refactor(agent_utils.py): standardize model name and cost calculation logic for clarity and maintainability
chore(anthropic_callback_handler.py): create a new file for the AnthropicCallbackHandler implementation and related functions
parent d194868cff
commit 2899b5f848
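Note on the new callback module: the commit message references a new ra_aid/callbacks/anthropic_callback_handler.py, but that file's contents are not included in this diff; only its import and wiring into agent_utils.py appear below. As a rough orientation, a LangChain callback handler that tracks Anthropic token usage and cost typically looks like the following sketch. The class body, attribute names, and per-token prices here are illustrative assumptions, not the implementation shipped in this commit.

# Illustrative sketch only; names and prices are assumptions, not the shipped code.
import threading
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class AnthropicCallbackHandler(BaseCallbackHandler):
    """Accumulates token usage and estimated cost across Anthropic LLM calls."""

    # Hypothetical USD prices per 1K input/output tokens, keyed by model name.
    MODEL_COST_PER_1K = {"claude-3-7-sonnet-20250219": (0.003, 0.015)}

    def __init__(self, model_name: str) -> None:
        super().__init__()
        self.model_name = model_name
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.total_cost = 0.0
        self._lock = threading.Lock()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # Chat model integrations generally report usage via llm_output; the exact
        # key layout varies by provider wrapper, so read it defensively.
        usage: Dict[str, int] = (response.llm_output or {}).get("usage", {})
        prompt = usage.get("input_tokens", 0)
        completion = usage.get("output_tokens", 0)
        in_price, out_price = self.MODEL_COST_PER_1K.get(self.model_name, (0.0, 0.0))
        with self._lock:
            self.prompt_tokens += prompt
            self.completion_tokens += completion
            self.total_tokens += prompt + completion
            self.total_cost += (prompt / 1000) * in_price + (completion / 1000) * out_price

    def __str__(self) -> str:
        return (
            f"Tokens: {self.total_tokens} "
            f"(prompt {self.prompt_tokens}, completion {self.completion_tokens}), "
            f"cost ${self.total_cost:.4f}"
        )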
@@ -10,6 +10,9 @@ import uuid
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional, Sequence

+from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
+
 import litellm
 from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
 from openai import RateLimitError as OpenAIRateLimitError
@@ -71,7 +74,11 @@ from ra_aid.prompts.human_prompts import (
 from ra_aid.prompts.implementation_prompts import IMPLEMENTATION_PROMPT
 from ra_aid.prompts.common_prompts import NEW_PROJECT_HINTS
 from ra_aid.prompts.planning_prompts import PLANNING_PROMPT
-from ra_aid.prompts.reasoning_assist_prompt import REASONING_ASSIST_PROMPT_PLANNING, REASONING_ASSIST_PROMPT_IMPLEMENTATION, REASONING_ASSIST_PROMPT_RESEARCH
+from ra_aid.prompts.reasoning_assist_prompt import (
+    REASONING_ASSIST_PROMPT_PLANNING,
+    REASONING_ASSIST_PROMPT_IMPLEMENTATION,
+    REASONING_ASSIST_PROMPT_RESEARCH,
+)
 from ra_aid.prompts.research_prompts import (
     RESEARCH_ONLY_PROMPT,
     RESEARCH_PROMPT,
@@ -90,9 +97,15 @@ from ra_aid.tool_configs import (
 )
 from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
 from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
-from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
-from ra_aid.database.repositories.human_input_repository import get_human_input_repository
-from ra_aid.database.repositories.research_note_repository import get_research_note_repository
+from ra_aid.database.repositories.key_snippet_repository import (
+    get_key_snippet_repository,
+)
+from ra_aid.database.repositories.human_input_repository import (
+    get_human_input_repository,
+)
+from ra_aid.database.repositories.research_note_repository import (
+    get_research_note_repository,
+)
 from ra_aid.database.repositories.work_log_repository import get_work_log_repository
 from ra_aid.model_formatters import format_key_facts_dict
 from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
@@ -332,7 +345,9 @@ def create_agent(
         if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
             agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-            return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+            return create_react_agent(
+                model, tools, interrupt_after=["tools"], **agent_kwargs
+            )
         else:
             logger.debug("Using CiaynAgent agent instance")
             return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
@@ -343,7 +358,9 @@ def create_agent(
         config = get_config_repository().get_all()
         max_input_tokens = get_model_token_limit(config, agent_type)
         agent_kwargs = build_agent_kwargs(checkpointer, max_input_tokens)
-        return create_react_agent(model, tools, interrupt_after=['tools'], **agent_kwargs)
+        return create_react_agent(
+            model, tools, interrupt_after=["tools"], **agent_kwargs
+        )


 def run_research_agent(
@@ -406,7 +423,9 @@ def run_research_agent(
             recent_inputs = human_input_repository.get_recent(1)
             if recent_inputs and len(recent_inputs) > 0:
                 last_human_input = recent_inputs[0].content
-                base_task = f"<last human input>{last_human_input}</last human input>\n{base_task}"
+                base_task = (
+                    f"<last human input>{last_human_input}</last human input>\n{base_task}"
+                )
         except RuntimeError as e:
             logger.error(f"Failed to access human input repository: {str(e)}")
             # Continue without appending last human input
@@ -416,7 +435,9 @@ def run_research_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
-    key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+    key_snippets = format_key_snippets_dict(
+        get_key_snippet_repository().get_snippets_dict()
+    )
     related_files = get_related_files()

     try:
@@ -432,20 +453,22 @@ def run_research_agent(
         human_interaction=hil,
         web_research_enabled=get_config_repository().get("web_research_enabled", False),
     )

     # Get model info for reasoning assistance configuration
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")

     # Get model configuration to check for reasoning_assist_default
     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )
     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -453,26 +476,31 @@ def run_research_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)
     expert_guidance = ""

     # Get research note information for reasoning assistance
     try:
-        research_notes = format_research_notes_dict(get_research_note_repository().get_notes_dict())
+        research_notes = format_research_notes_dict(
+            get_research_note_repository().get_notes_dict()
+        )
     except Exception as e:
         logger.warning(f"Failed to get research notes: {e}")
         research_notes = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
@@ -481,13 +509,13 @@ def run_research_agent(
                     tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_RESEARCH.format(
                 current_date=current_date,
@@ -500,62 +528,78 @@ def run_research_agent(
                 env_inv=get_env_inv(),
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best research approach."), title="📝 Thinking about research strategy...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best research approach."
+                    ),
+                    title="📝 Thinking about research strategy...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None

-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process content based on its type
             if isinstance(content, list):
                 # Handle structured thinking mode (e.g., Claude 3.7)
                 thinking_content = None
                 response_text = None

                 # Process each item in the list
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
-                    console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
-                    )
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
+                    console.print(
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
+                    )

                 # Use response_text if available, otherwise fall back to joining
                 if response_text:
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
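For reference, the list-form response.content that the extraction loop above walks is the shape Anthropic extended-thinking ("Claude 3.7 thinking") responses take when surfaced through LangChain: a list of dict blocks whose "type" is either "thinking" or "text". A small self-contained illustration (the block values are made up; only the keys mirror what the code checks):

content = [
    {"type": "thinking", "thinking": "First survey the repository layout and entry points..."},
    {"type": "text", "text": "Recommended research strategy: start from the CLI entry point..."},
]
# Same extraction idea as the loop above, condensed into two expressions.
thinking = next((b["thinking"] for b in content if b.get("type") == "thinking"), None)
answer = next((b["text"] for b in content if b.get("type") == "text"), None)
print(thinking is not None, answer is not None)  # True True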
@@ -563,22 +607,28 @@ def run_research_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Research Strategy Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Research Strategy Guidance",
+                    border_style="blue",
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY DURING RESEARCH"
+            )

             logger.info("Received expert guidance for research")
         except Exception as e:
             logger.error("Error getting expert guidance for research: %s", e)
             expert_guidance = ""

     agent = create_agent(model, tools, checkpointer=memory, agent_type="research")

     expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
@@ -588,7 +638,7 @@ def run_research_agent(
         if get_config_repository().get("web_research_enabled")
         else ""
     )

     # Prepare expert guidance section if expert guidance is available
     expert_guidance_section = ""
     if expert_guidance:
@@ -600,7 +650,7 @@ def run_research_agent(
     # We get research notes earlier for reasoning assistance

     # Get environment inventory information

     prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
         current_date=current_date,
         working_directory=working_directory,
@@ -643,9 +693,7 @@ def run_research_agent(
         if agent is not None:
             logger.debug("Research agent created successfully")
             none_or_fallback_handler = init_fallback_handler(agent, tools)
-            _result = run_agent_with_retry(
-                agent, prompt, none_or_fallback_handler
-            )
+            _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
             if _result:
                 # Log research completion
                 log_work_event(f"Completed research phase for: {base_task_or_query}")
@@ -731,7 +779,9 @@ def run_web_research_agent(
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""
@@ -741,7 +791,7 @@ def run_web_research_agent(
     working_directory = os.getcwd()

     # Get environment inventory information

     prompt = WEB_RESEARCH_PROMPT.format(
         current_date=current_date,
         working_directory=working_directory,
@@ -771,9 +821,7 @@ def run_web_research_agent(

         logger.debug("Web research agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log web research completion
             log_work_event(f"Completed web research phase for: {query}")
@@ -835,17 +883,19 @@ def run_planning_agent(
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")
     logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)

     # Get model configuration to check for reasoning_assist_default
     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )

     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -853,27 +903,29 @@ def run_planning_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)

     # Get all the context information (used both for normal planning and reasoning assist)
     current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     working_directory = os.getcwd()

     # Make sure key_facts is defined before using it
     try:
         key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""

     # Make sure key_snippets is defined before using it
     try:
-        key_snippets = format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict())
+        key_snippets = format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        )
     except RuntimeError as e:
         logger.error(f"Failed to access key snippet repository: {str(e)}")
         key_snippets = ""

     # Get formatted research notes using repository
     try:
         repository = get_research_note_repository()
@@ -882,28 +934,31 @@ def run_planning_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access research note repository: {str(e)}")
         formatted_research_notes = ""

     # Get related files
     related_files = "\n".join(get_related_files())

     # Get environment inventory information
     env_inv = get_env_inv()

     # Display the planning stage header before any reasoning assistance
     print_stage_header("Planning Stage")

     # Initialize expert guidance section
     expert_guidance = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting expert guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting expert guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
@@ -912,13 +967,13 @@ def run_planning_agent(
                     tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_PLANNING.format(
                 current_date=current_date,
@@ -931,62 +986,78 @@ def run_planning_agent(
                 env_inv=env_inv,
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best way to do this."), title="📝 Thinking about the plan...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best way to do this."
+                    ),
+                    title="📝 Thinking about the plan...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Get response content, handling if it's a list (for Claude thinking mode)
             content = None

-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process content based on its type
             if isinstance(content, list):
                 # Handle structured thinking mode (e.g., Claude 3.7)
                 thinking_content = None
                 response_text = None

                 # Process each item in the list
                 for item in content:
                     if isinstance(item, dict):
                         # Extract thinking content
-                        if item.get('type') == 'thinking' and 'thinking' in item:
-                            thinking_content = item['thinking']
+                        if item.get("type") == "thinking" and "thinking" in item:
+                            thinking_content = item["thinking"]
                             logger.debug("Found structured thinking content")
                         # Extract response text
-                        elif item.get('type') == 'text' and 'text' in item:
-                            response_text = item['text']
+                        elif item.get("type") == "text" and "text" in item:
+                            response_text = item["text"]
                             logger.debug("Found structured response text")

                 # Display thinking content in a separate panel if available
-                if thinking_content and get_config_repository().get("show_thoughts", False):
-                    logger.debug(f"Displaying structured thinking content ({len(thinking_content)} chars)")
-                    console.print(
-                        Panel(Markdown(thinking_content), title="💭 Expert Thinking", border_style="yellow")
-                    )
+                if thinking_content and get_config_repository().get(
+                    "show_thoughts", False
+                ):
+                    logger.debug(
+                        f"Displaying structured thinking content ({len(thinking_content)} chars)"
+                    )
+                    console.print(
+                        Panel(
+                            Markdown(thinking_content),
+                            title="💭 Expert Thinking",
+                            border_style="yellow",
+                        )
+                    )

                 # Use response_text if available, otherwise fall back to joining
                 if response_text:
                     content = response_text
                 else:
                     # Fallback: join list items if structured extraction failed
-                    logger.debug("No structured response text found, joining list items")
+                    logger.debug(
+                        "No structured response text found, joining list items"
+                    )
                     content = "\n".join(str(item) for item in content)
-            elif (supports_think_tag or supports_thinking):
+            elif supports_think_tag or supports_thinking:
                 # Process thinking content using the centralized function
                 content, _ = process_thinking_content(
                     content=content,
@@ -994,24 +1065,28 @@ def run_planning_agent(
                     supports_thinking=supports_thinking,
                     panel_title="💭 Expert Thinking",
                     panel_style="yellow",
-                    logger=logger
+                    logger=logger,
                 )

             # Display the expert guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Reasoning Guidance", border_style="blue")
+                Panel(
+                    Markdown(content), title="Reasoning Guidance", border_style="blue"
+                )
             )

             # Use the content as expert guidance
-            expert_guidance = content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            expert_guidance = (
+                content + "\n\nCONSULT WITH THE EXPERT FREQUENTLY ON THIS TASK"
+            )

             logger.info("Received expert guidance for planning")
         except Exception as e:
             logger.error("Error getting expert guidance for planning: %s", e)
             expert_guidance = ""

     agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")

     expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
     human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
     web_research_section = (
@@ -1019,7 +1094,7 @@ def run_planning_agent(
         if get_config_repository().get("web_research_enabled", False)
         else ""
     )

     # Prepare expert guidance section if expert guidance is available
     expert_guidance_section = ""
     if expert_guidance:
@@ -1050,7 +1125,9 @@ def run_planning_agent(
     )

     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1060,9 +1137,7 @@ def run_planning_agent(
     try:
         logger.debug("Planning agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, planning_prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, planning_prompt, none_or_fallback_handler)
        if _result:
            # Log planning completion
            log_work_event(f"Completed planning phase for: {base_task}")
@@ -1135,7 +1210,7 @@ def run_task_implementation_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access key fact repository: {str(e)}")
         key_facts = ""

     # Get formatted research notes using repository
     try:
         repository = get_research_note_repository()
@@ -1144,7 +1219,7 @@ def run_task_implementation_agent(
     except RuntimeError as e:
         logger.error(f"Failed to access research note repository: {str(e)}")
         formatted_research_notes = ""

     # Get latest project info
     try:
         project_info = get_project_info(".")
@@ -1152,24 +1227,26 @@ def run_task_implementation_agent(
     except Exception as e:
         logger.warning("Failed to get project info: %s", str(e))
         formatted_project_info = "Project info unavailable"

     # Get environment inventory information
     env_inv = get_env_inv()

     # Get model configuration to check for reasoning_assist_default
     provider = get_config_repository().get("provider", "")
     model_name = get_config_repository().get("model", "")
     logger.debug("Checking for reasoning_assist_default on %s/%s", provider, model_name)

     model_config = {}
     provider_models = models_params.get(provider, {})
     if provider_models and model_name in provider_models:
         model_config = provider_models[model_name]

     # Check if reasoning assist is explicitly enabled/disabled
     force_assistance = get_config_repository().get("force_reasoning_assistance", False)
-    disable_assistance = get_config_repository().get("disable_reasoning_assistance", False)
+    disable_assistance = get_config_repository().get(
+        "disable_reasoning_assistance", False
+    )

     if force_assistance:
         reasoning_assist_enabled = True
     elif disable_assistance:
@@ -1177,71 +1254,84 @@ def run_task_implementation_agent(
     else:
         # Fall back to model default
         reasoning_assist_enabled = model_config.get("reasoning_assist_default", False)

     logger.debug("Reasoning assist enabled: %s", reasoning_assist_enabled)

     # Initialize implementation guidance section
     implementation_guidance_section = ""

     # If reasoning assist is enabled, make a one-off call to the expert model
     if reasoning_assist_enabled:
         try:
-            logger.info("Reasoning assist enabled for model %s, getting implementation guidance", model_name)
+            logger.info(
+                "Reasoning assist enabled for model %s, getting implementation guidance",
+                model_name,
+            )

             # Collect tool descriptions
             tool_metadata = []
             from ra_aid.tools.reflection import get_function_info as get_tool_info

             for tool in tools:
                 try:
                     tool_info = get_tool_info(tool.func)
                     name = tool.func.__name__
                     description = inspect.getdoc(tool.func)
-                    tool_metadata.append(f"Tool: {name}\nDescription: {description}\n")
+                    tool_metadata.append(
+                        f"Tool: {name}\nDescription: {description}\n"
+                    )
                 except Exception as e:
                     logger.warning(f"Error getting tool info for {tool}: {e}")

             # Format tool metadata
             formatted_tool_metadata = "\n".join(tool_metadata)

             # Initialize expert model
             expert_model = initialize_expert_llm(provider, model_name)

             # Format the reasoning assist prompt for implementation
             reasoning_assist_prompt = REASONING_ASSIST_PROMPT_IMPLEMENTATION.format(
                 current_date=current_date,
                 working_directory=working_directory,
                 task=task,
                 key_facts=key_facts,
-                key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+                key_snippets=format_key_snippets_dict(
+                    get_key_snippet_repository().get_snippets_dict()
+                ),
                 research_notes=formatted_research_notes,
                 related_files="\n".join(related_files),
                 env_inv=env_inv,
                 tool_metadata=formatted_tool_metadata,
             )

             # Show the reasoning assist query in a panel
             console.print(
-                Panel(Markdown("Consulting with the reasoning model on the best implementation approach."), title="📝 Thinking about implementation...", border_style="yellow")
+                Panel(
+                    Markdown(
+                        "Consulting with the reasoning model on the best implementation approach."
+                    ),
+                    title="📝 Thinking about implementation...",
+                    border_style="yellow",
+                )
             )

             logger.debug("Invoking expert model for implementation reasoning assist")
             # Make the call to the expert model
             response = expert_model.invoke(reasoning_assist_prompt)

             # Check if the model supports think tags
             supports_think_tag = model_config.get("supports_think_tag", False)
             supports_thinking = model_config.get("supports_thinking", False)

             # Process response content
             content = None

-            if hasattr(response, 'content'):
+            if hasattr(response, "content"):
                 content = response.content
             else:
                 # Fallback if content attribute is missing
                 content = str(response)

             # Process the response content using the centralized function
             content, extracted_thinking = process_thinking_content(
                 content=content,
@@ -1249,24 +1339,28 @@ def run_task_implementation_agent(
                 supports_thinking=supports_thinking,
                 panel_title="💭 Implementation Thinking",
                 panel_style="yellow",
-                logger=logger
+                logger=logger,
             )

             # Display the implementation guidance in a panel
             console.print(
-                Panel(Markdown(content), title="Implementation Guidance", border_style="blue")
+                Panel(
+                    Markdown(content),
+                    title="Implementation Guidance",
+                    border_style="blue",
+                )
             )

             # Format the implementation guidance section for the prompt
             implementation_guidance_section = f"""<implementation guidance>
 {content}
 </implementation guidance>"""

             logger.info("Received implementation guidance")
         except Exception as e:
             logger.error("Error getting implementation guidance: %s", e)
             implementation_guidance_section = ""

     prompt = IMPLEMENTATION_PROMPT.format(
         current_date=current_date,
         working_directory=working_directory,
@@ -1276,7 +1370,9 @@ def run_task_implementation_agent(
         plan=plan,
         related_files=related_files,
         key_facts=key_facts,
-        key_snippets=format_key_snippets_dict(get_key_snippet_repository().get_snippets_dict()),
+        key_snippets=format_key_snippets_dict(
+            get_key_snippet_repository().get_snippets_dict()
+        ),
         research_notes=formatted_research_notes,
         work_log=get_work_log_repository().format_work_log(),
         expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
@@ -1296,7 +1392,9 @@ def run_task_implementation_agent(
     )

     config_values = get_config_repository().get_all()
-    recursion_limit = get_config_repository().get("recursion_limit", DEFAULT_RECURSION_LIMIT)
+    recursion_limit = get_config_repository().get(
+        "recursion_limit", DEFAULT_RECURSION_LIMIT
+    )
     run_config = {
         "configurable": {"thread_id": thread_id},
         "recursion_limit": recursion_limit,
@@ -1306,9 +1404,7 @@ def run_task_implementation_agent(
     try:
         logger.debug("Implementation agent completed successfully")
         none_or_fallback_handler = init_fallback_handler(agent, tools)
-        _result = run_agent_with_retry(
-            agent, prompt, none_or_fallback_handler
-        )
+        _result = run_agent_with_retry(agent, prompt, none_or_fallback_handler)
         if _result:
             # Log task implementation completion
             log_work_event(f"Completed implementation of task: {task}")
@@ -1380,27 +1476,37 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
     # 1. Check if this is a ValueError with 429 code or rate limit phrases
     if isinstance(e, ValueError):
         error_str = str(e).lower()
-        rate_limit_phrases = ["429", "rate limit", "too many requests", "quota exceeded"]
-        if "code" not in error_str and not any(phrase in error_str for phrase in rate_limit_phrases):
+        rate_limit_phrases = [
+            "429",
+            "rate limit",
+            "too many requests",
+            "quota exceeded",
+        ]
+        if "code" not in error_str and not any(
+            phrase in error_str for phrase in rate_limit_phrases
+        ):
             raise e

     # 2. Check for status_code or http_status attribute equal to 429
-    if hasattr(e, 'status_code') and e.status_code == 429:
+    if hasattr(e, "status_code") and e.status_code == 429:
         pass  # This is a rate limit error, continue with retry logic
-    elif hasattr(e, 'http_status') and e.http_status == 429:
+    elif hasattr(e, "http_status") and e.http_status == 429:
         pass  # This is a rate limit error, continue with retry logic
     # 3. Check for rate limit phrases in error message
     elif isinstance(e, Exception) and not isinstance(e, ValueError):
         error_str = str(e).lower()
-        if not any(phrase in error_str for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]) and not ("rate" in error_str and "limit" in error_str):
+        if not any(
+            phrase in error_str
+            for phrase in ["rate limit", "too many requests", "quota exceeded", "429"]
+        ) and not ("rate" in error_str and "limit" in error_str):
             # This doesn't look like a rate limit error, but we'll still retry other API errors
             pass

     # Apply common retry logic for all identified errors
     if attempt == max_retries - 1:
         logger.error("Max retries reached, failing: %s", str(e))
         raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")

     logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
     delay = base_delay * (2**attempt)
     print_error(
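The retry path above uses plain exponential backoff: delay = base_delay * (2**attempt). A tiny self-contained illustration of the resulting schedule (assuming base_delay is 1 second and five attempts):

base_delay = 1
for attempt in range(5):
    print(f"attempt {attempt}: wait {base_delay * (2**attempt)}s")
# attempt 0 -> 1s, attempt 1 -> 2s, attempt 2 -> 4s, attempt 3 -> 8s, attempt 4 -> 16s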
@ -1457,55 +1563,78 @@ def _handle_fallback_response(
|
||||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage]):
    """
    Streams agent output while handling completion and interruption.

    For each chunk, it logs the output, calls check_interrupt(), prints agent output,
    and then checks if is_completed() or should_exit() are true. If so, it resets completion
    flags and returns. After finishing a stream iteration (i.e. the for-loop over chunks),
    the function retrieves the agent's state. If the state indicates further steps (i.e.
    state.next is non-empty), it resumes execution via agent.invoke(None, stream_config);
    otherwise, it exits the loop.

    This function adheres to the latest LangGraph best practices (as of March 2025) for handling
    human-in-the-loop interruptions using interrupt_after=["tools"].
    """
    config = get_config_repository().get_all()
    stream_config = config.copy()

    cb = None
    if is_anthropic_claude(config):
        model_name = config.get("model", "")
        full_model_name = model_name
        cb = AnthropicCallbackHandler(full_model_name)

        if "callbacks" not in stream_config:
            stream_config["callbacks"] = []
        stream_config["callbacks"].append(cb)

    while True:
        # Process each chunk from the agent stream.
        for chunk in agent.stream({"messages": msg_list}, stream_config):
            logger.debug("Agent output: %s", chunk)
            check_interrupt()
            agent_type = get_agent_type(agent)
            print_agent_output(chunk, agent_type)

            if is_completed() or should_exit():
                reset_completion_flags()
                if cb:
                    logger.debug(f"AnthropicCallbackHandler:\n{cb}")
                return True  # Exit immediately when finished or signaled to exit.

        logger.debug("Stream iteration ended; checking agent state for continuation.")

        # Prepare state configuration, ensuring 'configurable' is present.
        state_config = get_config_repository().get_all().copy()
        if "configurable" not in state_config:
            logger.debug(
                "Key 'configurable' not found in config; adding it as an empty dict."
            )
            state_config["configurable"] = {}
        logger.debug("Using state_config for agent.get_state(): %s", state_config)

        try:
            state = agent.get_state(state_config)
            logger.debug("Agent state retrieved: %s", state)
        except Exception as e:
            logger.error(
                "Error retrieving agent state with state_config %s: %s", state_config, e
            )
            raise

        # If the state indicates that further steps remain (i.e. state.next is non-empty),
        # then resume execution by invoking the agent with no new input.
        if state.next:
            logger.debug(
                "State indicates continuation (state.next: %s); resuming execution.",
                state.next,
            )
            agent.invoke(None, stream_config)
            continue
        else:
            logger.debug("No continuation indicated in state; exiting stream loop.")
            break

    if cb:
        logger.debug(f"AnthropicCallbackHandler:\n{cb}")
    return True

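Note (editorial): when a callback handler was attached, the debug lines above log the handler's __repr__ (defined in the new anthropic_callback_handler.py below). With invented figures at Claude 3 Sonnet pricing, the logged summary would look roughly like:

    # Illustrative output only; the token counts are made up for the example.
    # AnthropicCallbackHandler:
    # Tokens Used: 15342
    #     Prompt Tokens: 14200
    #     Completion Tokens: 1142
    # Successful Requests: 3
    # Total Cost (USD): $0.059730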
def run_agent_with_retry(
    agent: RAgents,
    prompt: str,

@@ -1517,7 +1646,9 @@ def run_agent_with_retry(
    max_retries = 20
    base_delay = 1
    test_attempts = 0
    _max_test_retries = get_config_repository().get(
        "max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES
    )
    auto_test = get_config_repository().get("auto_test", False)
    original_prompt = prompt
    msg_list = [HumanMessage(content=prompt)]

ra_aid/callbacks/anthropic_callback_handler.py (new file)
@@ -0,0 +1,270 @@
"""Custom callback handlers for tracking token usage and costs."""

import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import BaseCallbackHandler

# Define cost per 1K tokens for various models
ANTHROPIC_MODEL_COSTS = {
    # Claude 3.7 Sonnet input
    "claude-3-7-sonnet-20250219": 0.003,
    "anthropic/claude-3.7-sonnet": 0.003,
    "claude-3.7-sonnet": 0.003,
    # Claude 3.7 Sonnet output
    "claude-3-7-sonnet-20250219-completion": 0.015,
    "anthropic/claude-3.7-sonnet-completion": 0.015,
    "claude-3.7-sonnet-completion": 0.015,
    # Claude 3 Opus input
    "claude-3-opus-20240229": 0.015,
    "anthropic/claude-3-opus": 0.015,
    "claude-3-opus": 0.015,
    # Claude 3 Opus output
    "claude-3-opus-20240229-completion": 0.075,
    "anthropic/claude-3-opus-completion": 0.075,
    "claude-3-opus-completion": 0.075,
    # Claude 3 Sonnet input
    "claude-3-sonnet-20240229": 0.003,
    "anthropic/claude-3-sonnet": 0.003,
    "claude-3-sonnet": 0.003,
    # Claude 3 Sonnet output
    "claude-3-sonnet-20240229-completion": 0.015,
    "anthropic/claude-3-sonnet-completion": 0.015,
    "claude-3-sonnet-completion": 0.015,
    # Claude 3 Haiku input
    "claude-3-haiku-20240307": 0.00025,
    "anthropic/claude-3-haiku": 0.00025,
    "claude-3-haiku": 0.00025,
    # Claude 3 Haiku output
    "claude-3-haiku-20240307-completion": 0.00125,
    "anthropic/claude-3-haiku-completion": 0.00125,
    "claude-3-haiku-completion": 0.00125,
    # Claude 2 input
    "claude-2": 0.008,
    "claude-2.0": 0.008,
    "claude-2.1": 0.008,
    # Claude 2 output
    "claude-2-completion": 0.024,
    "claude-2.0-completion": 0.024,
    "claude-2.1-completion": 0.024,
    # Claude Instant input
    "claude-instant-1": 0.0016,
    "claude-instant-1.2": 0.0016,
    # Claude Instant output
    "claude-instant-1-completion": 0.0055,
    "claude-instant-1.2-completion": 0.0055,
}
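Note (editorial): the table is keyed in dollars per 1K tokens. A quick unit check, not part of the commit:

    # 0.003 per 1K tokens is $3 per million input tokens;
    # 0.015 per 1K tokens is $15 per million output tokens.
    print(ANTHROPIC_MODEL_COSTS["claude-3.7-sonnet"] * 1000)             # ~3.0
    print(ANTHROPIC_MODEL_COSTS["claude-3.7-sonnet-completion"] * 1000)  # ~15.0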

def standardize_model_name(model_name: str, is_completion: bool = False) -> str:
    """
    Standardize the model name to a format that can be used for cost calculation.

    Args:
        model_name: Model name to standardize.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Standardized model name.
    """
    if not model_name:
        model_name = "claude-3-sonnet"

    model_name = model_name.lower()

    # Handle OpenRouter prefixes
    if model_name.startswith("anthropic/"):
        model_name = model_name[len("anthropic/") :]

    # Add completion suffix if needed
    if is_completion and not model_name.endswith("-completion"):
        model_name = model_name + "-completion"

    return model_name
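Note (editorial): a few illustrative calls, not part of the commit, showing what standardize_model_name returns:

    print(standardize_model_name("anthropic/claude-3.7-sonnet"))                 # claude-3.7-sonnet
    print(standardize_model_name("Claude-3-Opus-20240229", is_completion=True))  # claude-3-opus-20240229-completion
    print(standardize_model_name(""))                                            # claude-3-sonnet (default)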

def get_anthropic_token_cost_for_model(
    model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the model is used for completion or not.

    Returns:
        Cost in USD.
    """
    model_name = standardize_model_name(model_name, is_completion)

    if model_name not in ANTHROPIC_MODEL_COSTS:
        # Default to Claude 3 Sonnet pricing if model not found
        model_name = (
            "claude-3-sonnet" if not is_completion else "claude-3-sonnet-completion"
        )

    cost_per_1k = ANTHROPIC_MODEL_COSTS[model_name]
    total_cost = cost_per_1k * (num_tokens / 1000)

    return total_cost
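Note (editorial): the arithmetic is cost_per_1k * (num_tokens / 1000). A worked example, not part of the commit:

    # 2,000 prompt tokens on claude-3-sonnet at $0.003 per 1K tokens
    print(get_anthropic_token_cost_for_model("claude-3-sonnet", 2000))                     # 0.006
    # 500 completion tokens at $0.015 per 1K tokens
    print(get_anthropic_token_cost_for_model("claude-3-sonnet", 500, is_completion=True))  # 0.0075
    # Unknown model names fall back to Claude 3 Sonnet pricing
    print(get_anthropic_token_cost_for_model("some-future-model", 1000))                   # 0.003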

class AnthropicCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks Anthropic token usage and costs."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0
    model_name: str = "claude-3-sonnet"  # Default model

    def __init__(self, model_name: Optional[str] = None) -> None:
        super().__init__()
        self._lock = threading.Lock()
        if model_name:
            self.model_name = model_name

        # Default costs for Claude 3.7 Sonnet
        self.input_cost_per_token = 0.003 / 1000  # $3/M input tokens
        self.output_cost_per_token = 0.015 / 1000  # $15/M output tokens

    def __repr__(self) -> str:
        return (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost:.6f}"
        )

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record the model name if available."""
        if "name" in serialized:
            self.model_name = serialized["name"]

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Count tokens as they're generated."""
        with self._lock:
            self.completion_tokens += 1
            self.total_tokens += 1
            token_cost = get_anthropic_token_cost_for_model(
                self.model_name, 1, is_completion=True
            )
            self.total_cost += token_cost

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Collect token usage from response."""
        token_usage = {}

        # Try to extract token usage from response
        if hasattr(response, "llm_output") and response.llm_output:
            llm_output = response.llm_output
            if "token_usage" in llm_output:
                token_usage = llm_output["token_usage"]
            elif "usage" in llm_output:
                usage = llm_output["usage"]

                # Handle Anthropic's specific usage format
                if "input_tokens" in usage:
                    token_usage["prompt_tokens"] = usage["input_tokens"]
                if "output_tokens" in usage:
                    token_usage["completion_tokens"] = usage["output_tokens"]

            # Extract model name if available
            if "model_name" in llm_output:
                self.model_name = llm_output["model_name"]

        # Try to get usage from response.usage
        elif hasattr(response, "usage"):
            usage = response.usage
            if hasattr(usage, "prompt_tokens"):
                token_usage["prompt_tokens"] = usage.prompt_tokens
            if hasattr(usage, "completion_tokens"):
                token_usage["completion_tokens"] = usage.completion_tokens
            if hasattr(usage, "total_tokens"):
                token_usage["total_tokens"] = usage.total_tokens

        # Extract usage from generations if available
        elif hasattr(response, "generations") and response.generations:
            for gen in response.generations:
                if gen and hasattr(gen[0], "generation_info"):
                    gen_info = gen[0].generation_info or {}
                    if "usage" in gen_info:
                        token_usage = gen_info["usage"]
                        break

        # Update counts with lock to prevent race conditions
        with self._lock:
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)

            # Only update prompt tokens if we have them
            if prompt_tokens > 0:
                self.prompt_tokens += prompt_tokens
                self.total_tokens += prompt_tokens
                prompt_cost = get_anthropic_token_cost_for_model(
                    self.model_name, prompt_tokens, is_completion=False
                )
                self.total_cost += prompt_cost

            # Only update completion tokens if not already counted by on_llm_new_token
            if completion_tokens > 0 and completion_tokens > self.completion_tokens:
                additional_tokens = completion_tokens - self.completion_tokens
                self.completion_tokens = completion_tokens
                self.total_tokens += additional_tokens
                completion_cost = get_anthropic_token_cost_for_model(
                    self.model_name, additional_tokens, is_completion=True
                )
                self.total_cost += completion_cost

            self.successful_requests += 1

    def __copy__(self) -> "AnthropicCallbackHandler":
        """Return a copy of the callback handler."""
        return self

    def __deepcopy__(self, memo: Any) -> "AnthropicCallbackHandler":
        """Return a deep copy of the callback handler."""
        return self
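Note (editorial): a minimal sketch, not part of the commit, showing how on_llm_end accumulates usage from a response object that exposes a .usage attribute; SimpleNamespace stands in for a real LLMResult here:

    from types import SimpleNamespace

    handler = AnthropicCallbackHandler("claude-3-sonnet")
    fake_response = SimpleNamespace(
        usage=SimpleNamespace(prompt_tokens=1200, completion_tokens=300, total_tokens=1500)
    )
    handler.on_llm_end(fake_response)
    print(handler)
    # Tokens Used: 1500
    #     Prompt Tokens: 1200
    #     Completion Tokens: 300
    # Successful Requests: 1
    # Total Cost (USD): $0.008100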

# Create a context variable for our custom callback
anthropic_callback_var: ContextVar[Optional[AnthropicCallbackHandler]] = ContextVar(
    "anthropic_callback", default=None
)


@contextmanager
def get_anthropic_callback(
    model_name: Optional[str] = None,
) -> AnthropicCallbackHandler:
    """Get the Anthropic callback handler in a context manager,
    which conveniently exposes token and cost information.

    Args:
        model_name: Optional model name to use for cost calculation.

    Returns:
        AnthropicCallbackHandler: The Anthropic callback handler.

    Example:
        >>> with get_anthropic_callback("claude-3-sonnet") as cb:
        ...     # Use the callback handler
        ...     # cb.total_tokens, cb.total_cost will be available after
    """
    cb = AnthropicCallbackHandler(model_name)
    anthropic_callback_var.set(cb)
    yield cb
    anthropic_callback_var.set(None)
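Note (editorial): a hedged usage sketch, not part of the commit; `llm` stands for a hypothetical LangChain chat model configured for Anthropic:

    with get_anthropic_callback("claude-3-haiku") as cb:
        # llm.invoke("Summarize the repository", config={"callbacks": [cb]})
        pass
    print(cb.total_tokens, f"${cb.total_cost:.6f}")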