Set Default Max Token Limit with Provider/Model Dictionary and Limit Tokens for Anthropic Claude React Agent (#45)
* feat(agent_utils.py): add get_model_token_limit function to retrieve token limits for models based on provider and model name
  feat(models_tokens.py): create models_tokens module to store token limits for various models and providers
  test(agent_utils.py): implement unit tests for get_model_token_limit and create_agent functions to ensure correct behavior and error handling
* test: Add unit tests for token limiting and agent creation functionality
* fix: Correct indentation and add missing test function for error handling
* fix: Update test assertion to use messages_modifier instead of state_modifier
* feat(agent_utils.py): add limit_tokens function to manage message token limits and preserve system message
  fix(agent_utils.py): update get_model_token_limit to handle exceptions and return None on error
  fix(ciayn_agent.py): set default max_tokens to DEFAULT_TOKEN_LIMIT in CiaynAgent initialization
  feat(models_tokens.py): define DEFAULT_TOKEN_LIMIT for consistent token limit management across agents
  style(agent_utils.py): improve code formatting and consistency in function definitions and comments
  style(agent_utils.py): refactor imports for better organization and readability
  fix(test_agent_utils.py): correct test assertion to use state_modifier instead of messages_modifier for create_agent function
* refactor(agent_utils.py): improve docstring clarity and formatting for limit_tokens function to enhance readability
  refactor(test_agent_utils.py): format test assertions for consistency and readability in agent creation tests
* feat: Update limit_tokens function to support Dict type for state parameter
* feat: Update limit_tokens to handle both list and dict input types
* refactor: Extract duplicate token trimming logic into helper function
* refactor: Rename and update message trimming functions for clarity
* refactor: Extract agent kwargs logic into a helper method for reuse
* refactor: Rename _build_agent_kwargs to build_agent_kwargs for clarity
* fix: Ensure state_modifier is passed correctly for agent creation
* test: Add tests for create_agent token limiting configuration
* refactor: Simplify CiaynAgent instantiation to only use max_tokens parameter
* refactor: Remove is_react_agent parameter from build_agent_kwargs function
* test: Fix test assertions for state_modifier in agent creation tests
* fix: Update agent creation to handle checkpointer and simplify tests
* test: Remove unnecessary assertions from agent creation test
* feat: Implement token limiting configuration for create_agent function
* refactor: Remove unused model info and token limit retrieval code
* test: Fix assertion errors in agent creation tests and update state_modifier handling
* test: Remove commented-out code and clarify assertions in tests
* test: Fix assertion in test_create_agent_anthropic_token_limiting_disabled
* feat(main.py): add --limit-tokens argument to control token limiting in agent state
  fix(main.py): include limit_tokens in configuration to ensure proper state management
* test: Refactor agent creation tests for improved readability and consistency
* test: Modify error handling in create_agent test to use side_effect on get_model_token_limit
* test: Improve error handling in create_agent test to verify fallback behavior
* test: Trigger exception on get_model_token_limit in error handling test
* refactor(agent_utils.py): remove unused config parameter from create_agent function to simplify the function signature
  fix(agent_utils.py): ensure config is always retrieved from _global_memory with a default value to prevent potential errors
  test(tests/test_agent_utils.py): remove outdated test for create_agent error handling to clean up the test suite
* feat: Add debug print for agent_kwargs in create_agent function
* refactor: Replace lambda with inner function for state_modifier in agent_utils
* refactor: Simplify limit_tokens function to return only message sequences
* feat: Add debug print statements to show token trimming details in trim_messages
* PAIN
* feat(main.py): add debug print statement for args.chat to assist in troubleshooting chat mode
  feat(agent_utils.py): implement estimate_messages_tokens function to calculate total tokens in messages
  refactor(agent_utils.py): replace token counting logic in trim_messages_with_removal with estimate_messages_tokens for clarity
  refactor(agent_utils.py): modify state_modifier to accept model and max_tokens parameters for better flexibility
  refactor(agent_utils.py): update build_agent_kwargs to pass model to state_modifier for improved functionality
* feat: Add .direnvrc to manage Python virtual environment activation
* refactor: Update state_modifier to handle first message token count and trim messages
* chore: remove unused .direnvrc file to clean up project structure
  feat: add .envrc to .gitignore to prevent environment configuration file from being tracked
  fix: update help text for --disable-limit-tokens argument for clarity
  refactor: clean up imports in agent_utils.py for better readability
  refactor: remove unused functions and comments in agent_utils.py to streamline code
  test: add unit tests for state_modifier function to ensure correct message trimming behavior
* refactor: Remove commented-out code in create_agent function
* feat: Add is_anthropic_claude method to check provider and model name
* fix: Correct search/replace block to match existing lines in agent_utils.py
* fix(main.py): update help text for --disable-limit-tokens argument to clarify it applies to react agents
  refactor(agent_utils.py): streamline token limit retrieval and improve readability by removing redundant checks and restructuring code
  refactor(agent_utils.py): modify build_agent_kwargs to use is_anthropic_claude for clarity and maintainability
* test: Update tests to pass config argument to get_model_token_limit()
* refactor(agent_utils.py): remove unnecessary print statements and improve function signatures for clarity
  refactor(agent_utils.py): consolidate provider and model_name retrieval into config parameter for better maintainability
* test: Remove redundant token limiting tests from agent creation logic
* test: Refactor test_state_modifier to use mock_messages fixture
* test(tests): update test description for clarity on state_modifier behavior and use mock_messages for assertions to ensure consistency
* refactor(agent_utils.py): simplify token limit retrieval by removing unnecessary variable initialization and defaulting to None in get method
* chore(main.py): remove debug print statement for args.chat to clean up output
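For orientation, the pieces this change introduces fit together roughly as sketched below. This is a minimal usage sketch based on the helpers added to agent_utils.py and models_tokens.py in the diff that follows (get_model_token_limit, build_agent_kwargs, state_modifier, DEFAULT_TOKEN_LIMIT); the config values are illustrative, and in the real code path create_agent reads this config from _global_memory rather than taking it directly:

    # Hypothetical standalone use of the token-limiting helpers added below.
    from ra_aid.agent_utils import build_agent_kwargs, get_model_token_limit
    from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT

    # Illustrative config; in ra-aid this is built in main.py and stored in _global_memory.
    config = {
        "provider": "anthropic",
        "model": "claude-3-5-sonnet-20241022",
        "limit_tokens": True,  # set to False via --disable-limit-tokens
    }

    # Look up the model's context window in the provider/model dictionary,
    # falling back to the shared default when the model is unknown.
    token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT

    # For Anthropic Claude react agents this attaches a state_modifier that
    # trims older messages to fit within the limit while always keeping the
    # first (system) message; other providers get no trimming.
    agent_kwargs = build_agent_kwargs(
        checkpointer=None, config=config, token_limit=token_limit
    )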
This commit is contained in:
parent 0c39166172
commit 32fcf914ed
@@ -11,3 +11,4 @@ __pycache__/
/venv
/.idea
/htmlcov
.envrc
@ -6,112 +6,120 @@ from rich.panel import Panel
|
|||
from rich.console import Console
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from ra_aid.env import validate_environment
|
||||
from ra_aid.project_info import get_project_info, format_project_info, display_project_status
|
||||
from ra_aid.project_info import (
|
||||
get_project_info,
|
||||
format_project_info,
|
||||
display_project_status,
|
||||
)
|
||||
from ra_aid.tools.memory import _global_memory
|
||||
from ra_aid.tools.human import ask_human
|
||||
from ra_aid import print_stage_header, print_error
|
||||
from ra_aid.tools.human import ask_human
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.agent_utils import (
|
||||
AgentInterrupt,
|
||||
run_agent_with_retry,
|
||||
run_research_agent,
|
||||
run_planning_agent,
|
||||
create_agent
|
||||
)
|
||||
from ra_aid.prompts import (
|
||||
CHAT_PROMPT,
|
||||
WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
create_agent,
|
||||
)
|
||||
from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.logging_config import setup_logging, get_logger
|
||||
from ra_aid.tool_configs import (
|
||||
get_chat_tools
|
||||
)
|
||||
from ra_aid.tool_configs import get_chat_tools
|
||||
from ra_aid.dependencies import check_dependencies
|
||||
import os
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible', 'gemini']
|
||||
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022'
|
||||
OPENAI_DEFAULT_MODEL = 'gpt-4o'
|
||||
VALID_PROVIDERS = [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"openrouter",
|
||||
"openai-compatible",
|
||||
"gemini",
|
||||
]
|
||||
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
|
||||
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='RA.Aid - AI Agent for executing programming and research tasks',
|
||||
description="RA.Aid - AI Agent for executing programming and research tasks",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog='''
|
||||
epilog="""
|
||||
Examples:
|
||||
ra-aid -m "Add error handling to the database module"
|
||||
ra-aid -m "Explain the authentication flow" --research-only
|
||||
'''
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
'-m', '--message',
|
||||
"-m",
|
||||
"--message",
|
||||
type=str,
|
||||
help='The task or query to be executed by the agent'
|
||||
help="The task or query to be executed by the agent",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--version',
|
||||
action='version',
|
||||
version=f'%(prog)s {__version__}',
|
||||
help='Show program version number and exit'
|
||||
"--version",
|
||||
action="version",
|
||||
version=f"%(prog)s {__version__}",
|
||||
help="Show program version number and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--research-only',
|
||||
action='store_true',
|
||||
help='Only perform research without implementation'
|
||||
"--research-only",
|
||||
action="store_true",
|
||||
help="Only perform research without implementation",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--provider',
|
||||
"--provider",
|
||||
type=str,
|
||||
default='openai' if (os.getenv('OPENAI_API_KEY') and not os.getenv('ANTHROPIC_API_KEY')) else 'anthropic',
|
||||
default="openai"
|
||||
if (os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY"))
|
||||
else "anthropic",
|
||||
choices=VALID_PROVIDERS,
|
||||
help='The LLM provider to use'
|
||||
help="The LLM provider to use",
|
||||
)
|
||||
parser.add_argument("--model", type=str, help="The model name to use")
|
||||
parser.add_argument(
|
||||
"--cowboy-mode",
|
||||
action="store_true",
|
||||
help="Skip interactive approval for shell commands",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
"--expert-provider",
|
||||
type=str,
|
||||
help='The model name to use'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--cowboy-mode',
|
||||
action='store_true',
|
||||
help='Skip interactive approval for shell commands'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--expert-provider',
|
||||
type=str,
|
||||
default='openai',
|
||||
default="openai",
|
||||
choices=VALID_PROVIDERS,
|
||||
help='The LLM provider to use for expert knowledge queries (default: openai)'
|
||||
help="The LLM provider to use for expert knowledge queries (default: openai)",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--expert-model',
|
||||
"--expert-model",
|
||||
type=str,
|
||||
help='The model name to use for expert knowledge queries (required for non-OpenAI providers)'
|
||||
help="The model name to use for expert knowledge queries (required for non-OpenAI providers)",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--hil', '-H',
|
||||
action='store_true',
|
||||
help='Enable human-in-the-loop mode, where the agent can prompt the user for additional information.'
|
||||
"--hil",
|
||||
"-H",
|
||||
action="store_true",
|
||||
help="Enable human-in-the-loop mode, where the agent can prompt the user for additional information.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--chat',
|
||||
action='store_true',
|
||||
help='Enable chat mode with direct human interaction (implies --hil)'
|
||||
"--chat",
|
||||
action="store_true",
|
||||
help="Enable chat mode with direct human interaction (implies --hil)",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
help='Enable verbose logging output'
|
||||
"--verbose", action="store_true", help="Enable verbose logging output"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--temperature',
|
||||
"--temperature",
|
||||
type=float,
|
||||
help='LLM temperature (0.0-2.0). Controls randomness in responses',
|
||||
default=None
|
||||
help="LLM temperature (0.0-2.0). Controls randomness in responses",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-limit-tokens",
|
||||
action="store_false",
|
||||
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
||||
)
|
||||
|
||||
if args is None:
|
||||
|
|
@ -129,23 +137,34 @@ Examples:
|
|||
|
||||
if parsed_args.provider == "openai":
|
||||
parsed_args.model = parsed_args.model or OPENAI_DEFAULT_MODEL
|
||||
if parsed_args.provider == 'anthropic':
|
||||
if parsed_args.provider == "anthropic":
|
||||
# Always use default model for Anthropic
|
||||
parsed_args.model = ANTHROPIC_DEFAULT_MODEL
|
||||
elif not parsed_args.model and not parsed_args.research_only:
|
||||
# Require model for other providers unless in research mode
|
||||
parser.error(f"--model is required when using provider '{parsed_args.provider}'")
|
||||
parser.error(
|
||||
f"--model is required when using provider '{parsed_args.provider}'"
|
||||
)
|
||||
|
||||
# Validate expert model requirement
|
||||
if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
|
||||
parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
|
||||
if (
|
||||
parsed_args.expert_provider != "openai"
|
||||
and not parsed_args.expert_model
|
||||
and not parsed_args.research_only
|
||||
):
|
||||
parser.error(
|
||||
f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'"
|
||||
)
|
||||
|
||||
# Validate temperature range if provided
|
||||
if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0):
|
||||
parser.error('Temperature must be between 0.0 and 2.0')
|
||||
if parsed_args.temperature is not None and not (
|
||||
0.0 <= parsed_args.temperature <= 2.0
|
||||
):
|
||||
parser.error("Temperature must be between 0.0 and 2.0")
|
||||
|
||||
return parsed_args
|
||||
|
||||
|
||||
# Create console instance
|
||||
console = Console()
|
||||
|
||||
|
|
@ -157,14 +176,18 @@ implementation_memory = MemorySaver()
|
|||
|
||||
def is_informational_query() -> bool:
|
||||
"""Determine if the current query is informational based on implementation_requested state."""
|
||||
return _global_memory.get('config', {}).get('research_only', False) or not is_stage_requested('implementation')
|
||||
return _global_memory.get("config", {}).get(
|
||||
"research_only", False
|
||||
) or not is_stage_requested("implementation")
|
||||
|
||||
|
||||
def is_stage_requested(stage: str) -> bool:
|
||||
"""Check if a stage has been requested to proceed."""
|
||||
if stage == 'implementation':
|
||||
return _global_memory.get('implementation_requested', False)
|
||||
if stage == "implementation":
|
||||
return _global_memory.get("implementation_requested", False)
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the ra-aid command line tool."""
|
||||
args = parse_arguments()
|
||||
|
|
@ -175,26 +198,32 @@ def main():
|
|||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing
|
||||
expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
|
||||
validate_environment(args)
|
||||
) # Will exit if main env vars missing
|
||||
logger.debug("Environment validation successful")
|
||||
|
||||
if expert_missing:
|
||||
console.print(Panel(
|
||||
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" +
|
||||
"\n".join(f"- {m}" for m in expert_missing) +
|
||||
"\nSet the required environment variables or args to enable expert mode.",
|
||||
title="Expert Tools Disabled",
|
||||
style="yellow"
|
||||
))
|
||||
console.print(
|
||||
Panel(
|
||||
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
|
||||
+ "\n".join(f"- {m}" for m in expert_missing)
|
||||
+ "\nSet the required environment variables or args to enable expert mode.",
|
||||
title="Expert Tools Disabled",
|
||||
style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
if web_research_missing:
|
||||
console.print(Panel(
|
||||
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" +
|
||||
"\n".join(f"- {m}" for m in web_research_missing) +
|
||||
"\nSet the required environment variables to enable web research.",
|
||||
title="Web Research Disabled",
|
||||
style="yellow"
|
||||
))
|
||||
console.print(
|
||||
Panel(
|
||||
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
|
||||
+ "\n".join(f"- {m}" for m in web_research_missing)
|
||||
+ "\nSet the required environment variables to enable web research.",
|
||||
title="Web Research Disabled",
|
||||
style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Create the base model after validation
|
||||
model = initialize_llm(args.provider, args.model, temperature=args.temperature)
|
||||
|
|
@ -216,7 +245,9 @@ def main():
|
|||
formatted_project_info = ""
|
||||
|
||||
# Get initial request from user
|
||||
initial_request = ask_human.invoke({"question": "What would you like help with?"})
|
||||
initial_request = ask_human.invoke(
|
||||
{"question": "What would you like help with?"}
|
||||
)
|
||||
|
||||
# Get working directory and current date
|
||||
working_directory = os.getcwd()
|
||||
|
|
@ -230,31 +261,41 @@ def main():
|
|||
"cowboy_mode": args.cowboy_mode,
|
||||
"hil": True, # Always true in chat mode
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"initial_request": initial_request
|
||||
"initial_request": initial_request,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
}
|
||||
|
||||
# Store config in global memory
|
||||
_global_memory['config'] = config
|
||||
_global_memory['config']['provider'] = args.provider
|
||||
_global_memory['config']['model'] = args.model
|
||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||
_global_memory['config']['expert_model'] = args.expert_model
|
||||
_global_memory["config"] = config
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Create chat agent with appropriate tools
|
||||
chat_agent = create_agent(
|
||||
model,
|
||||
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled),
|
||||
checkpointer=MemorySaver()
|
||||
get_chat_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=web_research_enabled,
|
||||
),
|
||||
checkpointer=MemorySaver(),
|
||||
)
|
||||
|
||||
# Run chat agent and exit
|
||||
run_agent_with_retry(chat_agent, CHAT_PROMPT.format(
|
||||
run_agent_with_retry(
|
||||
chat_agent,
|
||||
CHAT_PROMPT.format(
|
||||
initial_request=initial_request,
|
||||
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "",
|
||||
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
if web_research_enabled
|
||||
else "",
|
||||
working_directory=working_directory,
|
||||
current_date=current_date,
|
||||
project_info=formatted_project_info
|
||||
), config)
|
||||
project_info=formatted_project_info,
|
||||
),
|
||||
config,
|
||||
)
|
||||
return
|
||||
|
||||
# Validate message is provided
|
||||
|
|
@ -268,19 +309,20 @@ def main():
|
|||
"recursion_limit": 100,
|
||||
"research_only": args.research_only,
|
||||
"cowboy_mode": args.cowboy_mode,
|
||||
"web_research_enabled": web_research_enabled
|
||||
"web_research_enabled": web_research_enabled,
|
||||
"limit_tokens": args.disable_limit_tokens,
|
||||
}
|
||||
|
||||
# Store config in global memory for access by is_informational_query
|
||||
_global_memory['config'] = config
|
||||
_global_memory["config"] = config
|
||||
|
||||
# Store model configuration
|
||||
_global_memory['config']['provider'] = args.provider
|
||||
_global_memory['config']['model'] = args.model
|
||||
_global_memory["config"]["provider"] = args.provider
|
||||
_global_memory["config"]["model"] = args.model
|
||||
|
||||
# Store expert provider and model in config
|
||||
_global_memory['config']['expert_provider'] = args.expert_provider
|
||||
_global_memory['config']['expert_model'] = args.expert_model
|
||||
_global_memory["config"]["expert_provider"] = args.expert_provider
|
||||
_global_memory["config"]["expert_model"] = args.expert_model
|
||||
|
||||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
|
@ -292,7 +334,7 @@ def main():
|
|||
research_only=args.research_only,
|
||||
hil=args.hil,
|
||||
memory=research_memory,
|
||||
config=config
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Proceed with planning and implementation if not an informational query
|
||||
|
|
@ -304,7 +346,7 @@ def main():
|
|||
expert_enabled=expert_enabled,
|
||||
hil=args.hil,
|
||||
memory=planning_memory,
|
||||
config=config
|
||||
config=config,
|
||||
)
|
||||
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
|
|
@ -313,5 +355,6 @@ def main():
|
|||
print()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -3,22 +3,27 @@
|
|||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, List, Dict, Sequence
|
||||
from langchain_core.messages import BaseMessage, trim_messages
|
||||
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from ra_aid.project_info import get_project_info, format_project_info, display_project_status
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
import threading
|
||||
|
||||
from ra_aid.project_info import (
|
||||
get_project_info,
|
||||
format_project_info,
|
||||
display_project_status,
|
||||
)
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.project_info import get_project_info, format_project_info, display_project_status
|
||||
from ra_aid.console.formatting import print_stage_header, print_error
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import tool
|
||||
from typing import List, Any
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
|
|
@ -26,7 +31,7 @@ from ra_aid.tool_configs import (
|
|||
get_implementation_tools,
|
||||
get_research_tools,
|
||||
get_planning_tools,
|
||||
get_web_research_tools
|
||||
get_web_research_tools,
|
||||
)
|
||||
from ra_aid.prompts import (
|
||||
IMPLEMENTATION_PROMPT,
|
||||
|
|
@ -41,13 +46,9 @@ from ra_aid.prompts import (
|
|||
HUMAN_PROMPT_SECTION_RESEARCH,
|
||||
PLANNING_PROMPT,
|
||||
EXPERT_PROMPT_SECTION_PLANNING,
|
||||
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
|
||||
HUMAN_PROMPT_SECTION_PLANNING,
|
||||
WEB_RESEARCH_PROMPT,
|
||||
EXPERT_PROMPT_SECTION_CHAT,
|
||||
CHAT_PROMPT,
|
||||
)
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
|
||||
|
|
@ -60,58 +61,189 @@ from ra_aid.tools.memory import (
|
|||
get_memory_value,
|
||||
get_related_files,
|
||||
)
|
||||
from ra_aid.tool_configs import get_research_tools
|
||||
from ra_aid.prompts import (
|
||||
RESEARCH_PROMPT,
|
||||
RESEARCH_ONLY_PROMPT,
|
||||
EXPERT_PROMPT_SECTION_RESEARCH,
|
||||
HUMAN_PROMPT_SECTION_RESEARCH
|
||||
)
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@tool
|
||||
def output_markdown_message(message: str) -> str:
|
||||
"""Outputs a message to the user, optionally prompting for input."""
|
||||
console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
|
||||
return "Message output."
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)


def state_modifier(
    state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages.

    Args:
        state: The current agent state containing messages
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]
    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_tokens - first_tokens

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining


def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
    """Get the token limit for the current model configuration.

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        provider = config.get("provider", "")
        model_name = config.get("model", "")

        provider_tokens = models_tokens.get(provider, {})
        token_limit = provider_tokens.get(model_name, None)
        if token_limit:
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {token_limit}"
            )
        else:
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return token_limit

    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None


def build_agent_kwargs(
    checkpointer: Optional[Any] = None,
    config: Dict[str, Any] = None,
    token_limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Build kwargs dictionary for agent creation.

    Args:
        checkpointer: Optional memory checkpointer
        config: Optional configuration dictionary
        token_limit: Optional token limit for the model

    Returns:
        Dictionary of kwargs for agent creation
    """
    agent_kwargs = {}

    if checkpointer is not None:
        agent_kwargs["checkpointer"] = checkpointer

    if config.get("limit_tokens", True) and is_anthropic_claude(config):

        def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
            return state_modifier(state, max_tokens=token_limit)

        agent_kwargs["state_modifier"] = wrapped_state_modifier

    return agent_kwargs


def is_anthropic_claude(config: Dict[str, Any]) -> bool:
    """Check if the provider and model name indicate an Anthropic Claude model.

    Args:
        config: Configuration dictionary containing the provider and model name

    Returns:
        bool: True if this is an Anthropic Claude model
    """
    provider = config.get("provider", "")
    model_name = config.get("model", "")
    return (
        provider.lower() == "anthropic"
        and model_name
        and "claude" in model_name.lower()
    )


def create_agent(
|
||||
model: BaseChatModel,
|
||||
tools: List[Any],
|
||||
*,
|
||||
checkpointer: Any = None
|
||||
checkpointer: Any = None,
|
||||
) -> Any:
|
||||
"""Create a react agent with the given configuration.
|
||||
|
||||
|
||||
Args:
|
||||
model: The LLM model to use
|
||||
tools: List of tools to provide to the agent
|
||||
checkpointer: Optional memory checkpointer
|
||||
|
||||
config: Optional configuration dictionary containing settings like:
|
||||
- limit_tokens (bool): Whether to apply token limiting (default: True)
|
||||
- provider (str): The LLM provider name
|
||||
- model (str): The model name
|
||||
|
||||
Returns:
|
||||
The created agent instance
|
||||
|
||||
Token limiting helps prevent context window overflow by trimming older messages
|
||||
while preserving system messages. It can be disabled by setting
|
||||
config['limit_tokens'] = False.
|
||||
"""
|
||||
try:
|
||||
# Get model name if available
|
||||
provider = _global_memory.get('config', {}).get('provider')
|
||||
model_name = _global_memory.get('config', {}).get('model')
|
||||
|
||||
config = _global_memory.get("config", {})
|
||||
token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
|
||||
|
||||
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN
|
||||
if provider == 'anthropic' and 'claude' in model_name:
|
||||
if is_anthropic_claude(config):
|
||||
logger.debug("Using create_react_agent to instantiate agent.")
|
||||
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
|
||||
return create_react_agent(model, tools, **agent_kwargs)
|
||||
else:
|
||||
logger.debug("Using CiaynAgent agent instance.")
|
||||
return CiaynAgent(model, tools)
|
||||
|
||||
logger.debug("Using CiaynAgent agent instance")
|
||||
return CiaynAgent(model, tools, max_tokens=token_limit)
|
||||
|
||||
except Exception as e:
|
||||
# Default to REACT agent if provider/model detection fails
|
||||
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
|
||||
return create_react_agent(model, tools, checkpointer=checkpointer)
|
||||
config = _global_memory.get("config", {})
|
||||
token_limit = get_model_token_limit(config)
|
||||
agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
|
||||
return create_react_agent(model, tools, **agent_kwargs)
|
||||
|
||||
|
||||
def run_research_agent(
|
||||
base_task_or_query: str,
|
||||
|
|
@ -124,7 +256,7 @@ def run_research_agent(
|
|||
memory: Optional[Any] = None,
|
||||
config: Optional[dict] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
console_message: Optional[str] = None
|
||||
console_message: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Run a research agent with the given configuration.
|
||||
|
||||
|
|
@ -153,8 +285,13 @@ def run_research_agent(
|
|||
"""
|
||||
thread_id = thread_id or str(uuid.uuid4())
|
||||
logger.debug("Starting research agent with thread_id=%s", thread_id)
|
||||
logger.debug("Research configuration: expert=%s, research_only=%s, hil=%s, web=%s",
|
||||
expert_enabled, research_only, hil, web_research_enabled)
|
||||
logger.debug(
|
||||
"Research configuration: expert=%s, research_only=%s, hil=%s, web=%s",
|
||||
expert_enabled,
|
||||
research_only,
|
||||
hil,
|
||||
web_research_enabled,
|
||||
)
|
||||
|
||||
# Initialize memory if not provided
|
||||
if memory is None:
|
||||
|
|
@ -169,7 +306,7 @@ def run_research_agent(
|
|||
research_only=research_only,
|
||||
expert_enabled=expert_enabled,
|
||||
human_interaction=hil,
|
||||
web_research_enabled=config.get('web_research_enabled', False)
|
||||
web_research_enabled=config.get("web_research_enabled", False),
|
||||
)
|
||||
|
||||
# Create agent
|
||||
|
|
@ -178,7 +315,11 @@ def run_research_agent(
|
|||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
|
||||
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
|
||||
web_research_section = WEB_RESEARCH_PROMPT_SECTION_RESEARCH if config.get('web_research_enabled') else ""
|
||||
web_research_section = (
|
||||
WEB_RESEARCH_PROMPT_SECTION_RESEARCH
|
||||
if config.get("web_research_enabled")
|
||||
else ""
|
||||
)
|
||||
|
||||
# Get research context from memory
|
||||
key_facts = _global_memory.get("key_facts", "")
|
||||
|
|
@ -196,29 +337,30 @@ def run_research_agent(
|
|||
# Build prompt
|
||||
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
|
||||
base_task=base_task_or_query,
|
||||
research_only_note='' if research_only else ' Only request implementation if the user explicitly asked for changes to be made.',
|
||||
research_only_note=""
|
||||
if research_only
|
||||
else " Only request implementation if the user explicitly asked for changes to be made.",
|
||||
expert_section=expert_section,
|
||||
human_section=human_section,
|
||||
web_research_section=web_research_section,
|
||||
key_facts=key_facts,
|
||||
work_log=get_memory_value('work_log'),
|
||||
work_log=get_memory_value("work_log"),
|
||||
code_snippets=code_snippets,
|
||||
related_files=related_files,
|
||||
project_info=formatted_project_info
|
||||
project_info=formatted_project_info,
|
||||
)
|
||||
|
||||
# Set up configuration
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": 100
|
||||
}
|
||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
||||
if config:
|
||||
run_config.update(config)
|
||||
|
||||
try:
|
||||
# Display console message if provided
|
||||
if console_message:
|
||||
console.print(Panel(Markdown(console_message), title="🔬 Looking into it..."))
|
||||
console.print(
|
||||
Panel(Markdown(console_message), title="🔬 Looking into it...")
|
||||
)
|
||||
|
||||
if project_info:
|
||||
display_project_status(project_info)
|
||||
|
|
@ -239,7 +381,7 @@ def run_research_agent(
|
|||
memory=memory,
|
||||
config=config,
|
||||
thread_id=thread_id,
|
||||
console_message=console_message
|
||||
console_message=console_message,
|
||||
)
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
|
|
@ -247,6 +389,7 @@ def run_research_agent(
|
|||
logger.error("Research agent failed: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def run_web_research_agent(
|
||||
query: str,
|
||||
model,
|
||||
|
|
@ -257,7 +400,7 @@ def run_web_research_agent(
|
|||
memory: Optional[Any] = None,
|
||||
config: Optional[dict] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
console_message: Optional[str] = None
|
||||
console_message: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Run a web research agent with the given configuration.
|
||||
|
||||
|
|
@ -284,8 +427,12 @@ def run_web_research_agent(
|
|||
"""
|
||||
thread_id = thread_id or str(uuid.uuid4())
|
||||
logger.debug("Starting web research agent with thread_id=%s", thread_id)
|
||||
logger.debug("Web research configuration: expert=%s, hil=%s, web=%s",
|
||||
expert_enabled, hil, web_research_enabled)
|
||||
logger.debug(
|
||||
"Web research configuration: expert=%s, hil=%s, web=%s",
|
||||
expert_enabled,
|
||||
hil,
|
||||
web_research_enabled,
|
||||
)
|
||||
|
||||
# Initialize memory if not provided
|
||||
if memory is None:
|
||||
|
|
@ -317,14 +464,11 @@ def run_web_research_agent(
|
|||
human_section=human_section,
|
||||
key_facts=key_facts,
|
||||
code_snippets=code_snippets,
|
||||
related_files=related_files
|
||||
related_files=related_files,
|
||||
)
|
||||
|
||||
# Set up configuration
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": 100
|
||||
}
|
||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
||||
if config:
|
||||
run_config.update(config)
|
||||
|
||||
|
|
@ -342,6 +486,7 @@ def run_web_research_agent(
|
|||
logger.error("Web research agent failed: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def run_planning_agent(
|
||||
base_task: str,
|
||||
model,
|
||||
|
|
@ -350,7 +495,7 @@ def run_planning_agent(
|
|||
hil: bool = False,
|
||||
memory: Optional[Any] = None,
|
||||
config: Optional[dict] = None,
|
||||
thread_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Run a planning agent to create implementation plans.
|
||||
|
||||
|
|
@ -379,7 +524,10 @@ def run_planning_agent(
|
|||
thread_id = str(uuid.uuid4())
|
||||
|
||||
# Configure tools
|
||||
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
tools = get_planning_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=config.get("web_research_enabled", False),
|
||||
)
|
||||
|
||||
# Create agent
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
|
@ -387,7 +535,11 @@ def run_planning_agent(
|
|||
# Format prompt sections
|
||||
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
|
||||
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
|
||||
web_research_section = WEB_RESEARCH_PROMPT_SECTION_PLANNING if config.get('web_research_enabled') else ""
|
||||
web_research_section = (
|
||||
WEB_RESEARCH_PROMPT_SECTION_PLANNING
|
||||
if config.get("web_research_enabled")
|
||||
else ""
|
||||
)
|
||||
|
||||
# Build prompt
|
||||
planning_prompt = PLANNING_PROMPT.format(
|
||||
|
|
@ -395,19 +547,18 @@ def run_planning_agent(
|
|||
human_section=human_section,
|
||||
web_research_section=web_research_section,
|
||||
base_task=base_task,
|
||||
research_notes=get_memory_value('research_notes'),
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
related_files="\n".join(get_related_files()),
|
||||
key_facts=get_memory_value('key_facts'),
|
||||
key_snippets=get_memory_value('key_snippets'),
|
||||
work_log=get_memory_value('work_log'),
|
||||
research_only_note='' if config.get('research_only') else ' Only request implementation if the user explicitly asked for changes to be made.'
|
||||
key_facts=get_memory_value("key_facts"),
|
||||
key_snippets=get_memory_value("key_snippets"),
|
||||
work_log=get_memory_value("work_log"),
|
||||
research_only_note=""
|
||||
if config.get("research_only")
|
||||
else " Only request implementation if the user explicitly asked for changes to be made.",
|
||||
)
|
||||
|
||||
# Set up configuration
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": 100
|
||||
}
|
||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
||||
if config:
|
||||
run_config.update(config)
|
||||
|
||||
|
|
@ -421,6 +572,7 @@ def run_planning_agent(
|
|||
logger.error("Planning agent failed: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def run_task_implementation_agent(
|
||||
base_task: str,
|
||||
tasks: list,
|
||||
|
|
@ -433,7 +585,7 @@ def run_task_implementation_agent(
|
|||
web_research_enabled: bool = False,
|
||||
memory: Optional[Any] = None,
|
||||
config: Optional[dict] = None,
|
||||
thread_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""Run an implementation agent for a specific task.
|
||||
|
||||
|
|
@ -454,7 +606,11 @@ def run_task_implementation_agent(
|
|||
"""
|
||||
thread_id = thread_id or str(uuid.uuid4())
|
||||
logger.debug("Starting implementation agent with thread_id=%s", thread_id)
|
||||
logger.debug("Implementation configuration: expert=%s, web=%s", expert_enabled, web_research_enabled)
|
||||
logger.debug(
|
||||
"Implementation configuration: expert=%s, web=%s",
|
||||
expert_enabled,
|
||||
web_research_enabled,
|
||||
)
|
||||
logger.debug("Task details: base_task=%s, current_task=%s", base_task, task)
|
||||
logger.debug("Related files: %s", related_files)
|
||||
|
||||
|
|
@ -467,7 +623,10 @@ def run_task_implementation_agent(
|
|||
thread_id = str(uuid.uuid4())
|
||||
|
||||
# Configure tools
|
||||
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False))
|
||||
tools = get_implementation_tools(
|
||||
expert_enabled=expert_enabled,
|
||||
web_research_enabled=config.get("web_research_enabled", False),
|
||||
)
|
||||
|
||||
# Create agent
|
||||
agent = create_agent(model, tools, checkpointer=memory)
|
||||
|
|
@ -479,20 +638,21 @@ def run_task_implementation_agent(
|
|||
tasks=tasks,
|
||||
plan=plan,
|
||||
related_files=related_files,
|
||||
key_facts=get_memory_value('key_facts'),
|
||||
key_snippets=get_memory_value('key_snippets'),
|
||||
research_notes=get_memory_value('research_notes'),
|
||||
work_log=get_memory_value('work_log'),
|
||||
key_facts=get_memory_value("key_facts"),
|
||||
key_snippets=get_memory_value("key_snippets"),
|
||||
research_notes=get_memory_value("research_notes"),
|
||||
work_log=get_memory_value("work_log"),
|
||||
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
|
||||
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "",
|
||||
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if config.get('web_research_enabled') else ""
|
||||
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
|
||||
if _global_memory.get("config", {}).get("hil", False)
|
||||
else "",
|
||||
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
|
||||
if config.get("web_research_enabled")
|
||||
else "",
|
||||
)
|
||||
|
||||
# Set up configuration
|
||||
run_config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"recursion_limit": 100
|
||||
}
|
||||
run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
|
||||
if config:
|
||||
run_config.update(config)
|
||||
|
||||
|
|
@ -505,10 +665,12 @@ def run_task_implementation_agent(
|
|||
logger.error("Implementation agent failed: %s", str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
_CONTEXT_STACK = []
|
||||
_INTERRUPT_CONTEXT = None
|
||||
_FEEDBACK_MODE = False
|
||||
|
||||
|
||||
def _request_interrupt(signum, frame):
|
||||
global _INTERRUPT_CONTEXT
|
||||
if _CONTEXT_STACK:
|
||||
|
|
@ -520,6 +682,7 @@ def _request_interrupt(signum, frame):
|
|||
print()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
class InterruptibleSection:
|
||||
def __enter__(self):
|
||||
_CONTEXT_STACK.append(self)
|
||||
|
|
@ -528,10 +691,12 @@ class InterruptibleSection:
|
|||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
_CONTEXT_STACK.remove(self)
|
||||
|
||||
|
||||
def check_interrupt():
|
||||
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
|
||||
raise AgentInterrupt("Interrupt requested")
|
||||
|
||||
|
||||
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
||||
"""Run an agent with retry logic for API errors."""
|
||||
logger.debug("Running agent with prompt length: %d", len(prompt))
|
||||
|
|
@ -546,48 +711,68 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
|
|||
with InterruptibleSection():
|
||||
try:
|
||||
# Track agent execution depth
|
||||
current_depth = _global_memory.get('agent_depth', 0)
|
||||
_global_memory['agent_depth'] = current_depth + 1
|
||||
current_depth = _global_memory.get("agent_depth", 0)
|
||||
_global_memory["agent_depth"] = current_depth + 1
|
||||
|
||||
for attempt in range(max_retries):
|
||||
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
|
||||
check_interrupt()
|
||||
try:
|
||||
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config):
|
||||
for chunk in agent.stream(
|
||||
{"messages": [HumanMessage(content=prompt)]}, config
|
||||
):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
if _global_memory['plan_completed']:
|
||||
_global_memory['plan_completed'] = False
|
||||
_global_memory['task_completed'] = False
|
||||
_global_memory['completion_message'] = ''
|
||||
if _global_memory["plan_completed"]:
|
||||
_global_memory["plan_completed"] = False
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["completion_message"] = ""
|
||||
break
|
||||
if _global_memory['task_completed']:
|
||||
_global_memory['task_completed'] = False
|
||||
_global_memory['completion_message'] = ''
|
||||
if _global_memory["task_completed"]:
|
||||
_global_memory["task_completed"] = False
|
||||
_global_memory["completion_message"] = ""
|
||||
break
|
||||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except (InternalServerError, APITimeoutError, RateLimitError, APIError, ValueError) as e:
|
||||
except (
|
||||
InternalServerError,
|
||||
APITimeoutError,
|
||||
RateLimitError,
|
||||
APIError,
|
||||
ValueError,
|
||||
) as e:
|
||||
if isinstance(e, ValueError):
|
||||
error_str = str(e).lower()
|
||||
if 'code' not in error_str or '429' not in error_str:
|
||||
if "code" not in error_str or "429" not in error_str:
|
||||
raise # Re-raise ValueError if it's not a Lambda 429
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("Max retries reached, failing: %s", str(e))
|
||||
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
|
||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2 ** attempt)
|
||||
print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})")
|
||||
raise RuntimeError(
|
||||
f"Max retries ({max_retries}) exceeded. Last error: {e}"
|
||||
)
|
||||
logger.warning(
|
||||
"API error (attempt %d/%d): %s",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
str(e),
|
||||
)
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
# Reset depth tracking
|
||||
_global_memory['agent_depth'] = _global_memory.get('agent_depth', 1) - 1
|
||||
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
|
||||
|
||||
if original_handler and threading.current_thread() is threading.main_thread():
|
||||
if (
|
||||
original_handler
|
||||
and threading.current_thread() is threading.main_thread()
|
||||
):
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
from typing import Dict, Any, Generator, List, Optional, Union
|
||||
|
||||
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage
|
||||
|
|
@ -66,7 +66,7 @@ class CiaynAgent:
|
|||
"""
|
||||
|
||||
|
||||
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
|
||||
def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT):
|
||||
"""Initialize the agent with a model and list of tools.
|
||||
|
||||
Args:
|
||||
|
|
@ -263,6 +263,10 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
text = content.content
|
||||
else:
|
||||
text = content
|
||||
|
||||
# create-react-agent tool calls can be lists
|
||||
if isinstance(text, List):
|
||||
return 0
|
||||
|
||||
if not text:
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -104,4 +104,4 @@ def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> B
|
|||
model=model_name,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,266 @@
|
|||
"""
|
||||
List of model tokens
|
||||
"""
|
||||
|
||||
DEFAULT_TOKEN_LIMIT = 100000
|
||||
|
||||
models_tokens = {
|
||||
"openai": {
|
||||
"gpt-3.5-turbo-0125": 16385,
|
||||
"gpt-3.5": 4096,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"gpt-3.5-turbo-1106": 16385,
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-turbo-2024-04-09": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-2024-08-06": 128000,
|
||||
"gpt-4o-2024-05-13": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
},
|
||||
"azure_openai": {
|
||||
"gpt-3.5-turbo-0125": 16385,
|
||||
"gpt-3.5": 4096,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"gpt-3.5-turbo-1106": 16385,
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-turbo-2024-04-09": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"chatgpt-4o-latest": 128000,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
},
|
||||
"google_genai": {
|
||||
"gemini-pro": 128000,
|
||||
"gemini-1.5-flash-latest": 128000,
|
||||
"gemini-1.5-pro-latest": 128000,
|
||||
"models/embedding-001": 2048,
|
||||
},
|
||||
"google_vertexai": {
|
||||
"gemini-1.5-flash": 128000,
|
||||
"gemini-1.5-pro": 128000,
|
||||
"gemini-1.0-pro": 128000,
|
||||
},
|
||||
"ollama": {
|
||||
"command-r": 12800,
|
||||
"codellama": 16000,
|
||||
"dbrx": 32768,
|
||||
"deepseek-coder:33b": 16000,
|
||||
"falcon": 2048,
|
||||
"llama2": 4096,
|
||||
"llama2:7b": 4096,
|
||||
"llama2:13b": 4096,
|
||||
"llama2:70b": 4096,
|
||||
"llama3": 8192,
|
||||
"llama3:8b": 8192,
|
||||
"llama3:70b": 8192,
|
||||
"llama3.1": 128000,
|
||||
"llama3.1:8b": 128000,
|
||||
"llama3.1:70b": 128000,
|
||||
"lama3.1:405b": 128000,
|
||||
"llama3.2": 128000,
|
||||
"llama3.2:1b": 128000,
|
||||
"llama3.2:3b": 128000,
|
||||
"llama3.3:70b": 128000,
|
||||
"scrapegraph": 8192,
|
||||
"mistral-small": 128000,
|
||||
"mistral-openorca": 32000,
|
||||
"mistral-large": 128000,
|
||||
"grok-1": 8192,
|
||||
"llava": 4096,
|
||||
"mixtral:8x22b-instruct": 65536,
|
||||
"nomic-embed-text": 8192,
|
||||
"nous-hermes2:34b": 4096,
|
||||
"orca-mini": 2048,
|
||||
"phi3:3.8b": 12800,
|
||||
"phi3:14b": 128000,
|
||||
"qwen:0.5b": 32000,
|
||||
"qwen:1.8b": 32000,
|
||||
"qwen:4b": 32000,
|
||||
"qwen:14b": 32000,
|
||||
"qwen:32b": 32000,
|
||||
"qwen:72b": 32000,
|
||||
"qwen:110b": 32000,
|
||||
"stablelm-zephyr": 8192,
|
||||
"wizardlm2:8x22b": 65536,
|
||||
"mistral": 128000,
|
||||
"gemma2": 128000,
|
||||
"gemma2:9b": 128000,
|
||||
"gemma2:27b": 128000,
|
||||
# embedding models
|
||||
"shaw/dmeta-embedding-zh-small-q4": 8192,
|
||||
"shaw/dmeta-embedding-zh-q4": 8192,
|
||||
"chevalblanc/acge_text_embedding": 8192,
|
||||
"martcreation/dmeta-embedding-zh": 8192,
|
||||
"snowflake-arctic-embed": 8192,
|
||||
"mxbai-embed-large": 512,
|
||||
},
|
||||
"oneapi": {
|
||||
"qwen-turbo": 6000,
|
||||
},
|
||||
"nvidia": {
|
||||
"meta/llama3-70b-instruct": 419,
|
||||
"meta/llama3-8b-instruct": 419,
|
||||
"nemotron-4-340b-instruct": 1024,
|
||||
"databricks/dbrx-instruct": 4096,
|
||||
"google/codegemma-7b": 8192,
|
||||
"google/gemma-2b": 2048,
|
||||
"google/gemma-7b": 8192,
|
||||
"google/recurrentgemma-2b": 2048,
|
||||
"meta/codellama-70b": 16384,
|
||||
"meta/llama2-70b": 4096,
|
||||
"microsoft/phi-3-mini-128k-instruct": 122880,
|
||||
"mistralai/mistral-7b-instruct-v0.2": 4096,
|
||||
"mistralai/mistral-large": 8192,
|
||||
"mistralai/mixtral-8x22b-instruct-v0.1": 32768,
|
||||
"mistralai/mixtral-8x7b-instruct-v0.1": 8192,
|
||||
"snowflake/arctic": 16384,
|
||||
},
|
||||
"groq": {
|
||||
"llama3-8b-8192": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"mixtral-8x7b-32768": 32768,
|
||||
"gemma-7b-it": 8192,
|
||||
"claude-3-haiku-20240307'": 8192,
|
||||
},
"toghetherai": {
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000,
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": 128000,
"mistralai/Mixtral-8x22B-Instruct-v0.1": 128000,
"stabilityai/stable-diffusion-xl-base-1.0": 2048,
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 128000,
"NousResearch/Hermes-3-Llama-3.1-405B-Turbo": 128000,
"Gryphe/MythoMax-L2-13b-Lite": 8192,
"Salesforce/Llama-Rank-V1": 8192,
"meta-llama/Meta-Llama-Guard-3-8B": 128000,
"meta-llama/Meta-Llama-3-70B-Instruct-Turbo": 128000,
"meta-llama/Llama-3-8b-chat-hf": 8192,
"meta-llama/Llama-3-70b-chat-hf": 8192,
"Qwen/Qwen2-72B-Instruct": 128000,
"google/gemma-2-27b-it": 8192,
},
"anthropic": {
"claude_instant": 100000,
"claude2": 9000,
"claude2.1": 200000,
"claude3": 200000,
"claude3.5": 200000,
"claude-3-opus-20240229": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-haiku-20240307": 200000,
"claude-3-5-sonnet-20240620": 200000,
"claude-3-5-sonnet-20241022": 200000,
"claude-3-5-haiku-latest": 200000,
},
"bedrock": {
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"claude-3-5-haiku-latest": 200000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2": 100000,
"anthropic.claude-instant-v1": 100000,
"meta.llama3-8b-instruct-v1:0": 8192,
"meta.llama3-70b-instruct-v1:0": 8192,
"meta.llama2-13b-chat-v1": 4096,
"meta.llama2-70b-chat-v1": 4096,
"mistral.mistral-7b-instruct-v0:2": 32768,
"mistral.mixtral-8x7b-instruct-v0:1": 32768,
"mistral.mistral-large-2402-v1:0": 32768,
"mistral.mistral-small-2402-v1:0": 32768,
"amazon.titan-embed-text-v1": 8000,
"amazon.titan-embed-text-v2:0": 8000,
"cohere.embed-english-v3": 512,
"cohere.embed-multilingual-v3": 512,
},
"mistralai": {
"mistral-large-latest": 128000,
"open-mistral-nemo": 128000,
"codestral-latest": 32000,
"mistral-embed": 8000,
"open-mistral-7b": 32000,
"open-mixtral-8x7b": 32000,
"open-mixtral-8x22b": 64000,
"open-codestral-mamba": 256000,
},
"hugging_face": {
"xai-org/grok-1": 8192,
"meta-llama/Meta-Llama-3-8B": 8192,
"meta-llama/Meta-Llama-3-8B-Instruct": 8192,
"meta-llama/Meta-Llama-3-70B": 8192,
"meta-llama/Meta-Llama-3-70B-Instruct": 8192,
"google/gemma-2b": 8192,
"google/gemma-2b-it": 8192,
"google/gemma-7b": 8192,
"google/gemma-7b-it": 8192,
"microsoft/phi-2": 2048,
"openai-community/gpt2": 1024,
"openai-community/gpt2-medium": 1024,
"openai-community/gpt2-large": 1024,
"facebook/opt-125m": 2048,
"petals-team/StableBeluga2": 8192,
"distilbert/distilgpt2": 1024,
"mistralai/Mistral-7B-Instruct-v0.2": 32768,
"gradientai/Llama-3-8B-Instruct-Gradient-1048k": 1040200,
"NousResearch/Hermes-2-Pro-Llama-3-8B": 8192,
"NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF": 8192,
"nvidia/Llama3-ChatQA-1.5-8B": 8192,
"microsoft/Phi-3-mini-4k-instruct": 4192,
"microsoft/Phi-3-mini-128k-instruct": 131072,
"mlabonne/Meta-Llama-3-120B-Instruct": 8192,
"cognitivecomputations/dolphin-2.9-llama3-8b": 8192,
"cognitivecomputations/dolphin-2.9-llama3-8b-gguf": 8192,
"cognitivecomputations/dolphin-2.8-mistral-7b-v02": 32768,
"cognitivecomputations/dolphin-2.5-mixtral-8x7b": 32768,
"TheBloke/dolphin-2.7-mixtral-8x7b-GGUF": 32768,
"deepseek-ai/DeepSeek-V2": 131072,
"deepseek-ai/DeepSeek-V2-Chat": 131072,
"claude-3-haiku": 200000,
},
"deepseek": {
"deepseek-chat": 28672,
"deepseek-coder": 16384,
},
"ernie": {
"ernie-bot-turbo": 4096,
"ernie-bot": 4096,
"ernie-bot-2": 4096,
"ernie-bot-2-base": 4096,
"ernie-bot-2-base-zh": 4096,
"ernie-bot-2-base-en": 4096,
"ernie-bot-2-base-en-zh": 4096,
"ernie-bot-2-base-zh-en": 4096,
},
"fireworks": {
"llama-v2-7b": 4096,
"mixtral-8x7b-instruct": 4096,
"nomic-ai/nomic-embed-text-v1.5": 8192,
"llama-3.1-405B-instruct": 131072,
"llama-3.1-70B-instruct": 131072,
"llama-3.1-8B-instruct": 131072,
"mixtral-moe-8x22B-instruct": 65536,
"mixtral-moe-8x7B-instruct": 65536,
},
"togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000},
}
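The token limits above are only data; the lookup contract is pinned down by the unit tests further below (anthropic/claude2 resolves to a value, an unknown provider or an empty config returns None). A minimal sketch of a get_model_token_limit with that contract, assuming a plain nested-dict lookup over models_tokens, could be:

from typing import Dict, Optional

from ra_aid.models_tokens import models_tokens


def get_model_token_limit(config: Dict[str, str]) -> Optional[int]:
    """Return the known token limit for the configured provider/model, else None."""
    try:
        provider = config.get("provider", "")
        model = config.get("model", "")
        # An unknown provider or model falls through to None rather than raising.
        return models_tokens.get(provider, {}).get(model)
    except Exception:
        # Defensive fallback: treat any lookup error as "no limit known".
        return None

Callers that get None back can fall back to DEFAULT_TOKEN_LIMIT, which is what the CiaynAgent tests below assert.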

@@ -1,5 +1,5 @@
import os
-from typing import List, Optional, Dict, Union
+from typing import List, Dict, Union
from ra_aid.tools.memory import _global_memory
from langchain_core.tools import tool
from rich.console import Console

@@ -7,7 +7,6 @@ from rich.panel import Panel
from rich.markdown import Markdown
from rich.text import Text
from ra_aid.proc.interactive import run_interactive_command
from pydantic import BaseModel, Field
from ra_aid.text.processing import truncate_output

console = Console()
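The new test module below pins down how these pieces fit together: for Anthropic models, create_agent is expected to go through create_react_agent, attaching a token-limiting state_modifier unless limit_tokens is disabled, while other providers fall back to CiaynAgent with an explicit max_tokens. A hedged sketch of that dispatch, under a hypothetical create_agent_sketch name (the config key and the create_react_agent import location are assumptions, not taken from this diff), might look like:

from langgraph.prebuilt import create_react_agent  # assumed import location; tests patch ra_aid.agent_utils.create_react_agent

from ra_aid.agent_utils import get_model_token_limit, state_modifier
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.memory import _global_memory


def create_agent_sketch(model, tools, checkpointer=None):
    # Illustrative only; mirrors the behaviour asserted by the tests below.
    config = _global_memory.get("config", {})  # "config" key assumed
    token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT

    if config.get("provider") == "anthropic":
        kwargs = {}
        if config.get("limit_tokens", True):
            # Bind the message-trimming state_modifier to the resolved limit.
            kwargs["state_modifier"] = lambda state: state_modifier(state, max_tokens=token_limit)
        return create_react_agent(model, tools, **kwargs)

    # Other providers: CiaynAgent only ever receives max_tokens
    # (the checkpointer test asserts it is not forwarded here).
    return CiaynAgent(model, tools, max_tokens=token_limit)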

@@ -0,0 +1,204 @@
"""Unit tests for agent_utils.py."""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.agent_utils import state_modifier, AgentState
|
||||
from unittest.mock import Mock, patch
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from ra_aid.agent_utils import create_agent, get_model_token_limit
|
||||
from ra_aid.models_tokens import models_tokens
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
"""Fixture providing a mock LLM model."""
|
||||
model = Mock(spec=BaseChatModel)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memory():
|
||||
"""Fixture providing a mock global memory store."""
|
||||
with patch("ra_aid.agent_utils._global_memory") as mock_mem:
|
||||
mock_mem.get.return_value = {}
|
||||
yield mock_mem
|
||||
|
||||
|
||||
def test_get_model_token_limit_anthropic(mock_memory):
|
||||
"""Test get_model_token_limit with Anthropic model."""
|
||||
config = {"provider": "anthropic", "model": "claude2"}
|
||||
|
||||
token_limit = get_model_token_limit(config)
|
||||
assert token_limit == models_tokens["anthropic"]["claude2"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_openai(mock_memory):
|
||||
"""Test get_model_token_limit with OpenAI model."""
|
||||
config = {"provider": "openai", "model": "gpt-4"}
|
||||
|
||||
token_limit = get_model_token_limit(config)
|
||||
assert token_limit == models_tokens["openai"]["gpt-4"]
|
||||
|
||||
|
||||
def test_get_model_token_limit_unknown(mock_memory):
|
||||
"""Test get_model_token_limit with unknown provider/model."""
|
||||
config = {"provider": "unknown", "model": "unknown-model"}
|
||||
|
||||
token_limit = get_model_token_limit(config)
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_get_model_token_limit_missing_config(mock_memory):
|
||||
"""Test get_model_token_limit with missing configuration."""
|
||||
config = {}
|
||||
|
||||
token_limit = get_model_token_limit(config)
|
||||
assert token_limit is None
|
||||
|
||||
|
||||
def test_create_agent_anthropic(mock_model, mock_memory):
|
||||
"""Test create_agent with Anthropic Claude model."""
|
||||
mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}
|
||||
|
||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
|
||||
mock_react.return_value = "react_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "react_agent"
|
||||
mock_react.assert_called_once_with(
|
||||
mock_model, [], state_modifier=mock_react.call_args[1]["state_modifier"]
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_openai(mock_model, mock_memory):
|
||||
"""Test create_agent with OpenAI model."""
|
||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "ciayn_agent"
|
||||
mock_ciayn.assert_called_once_with(
|
||||
mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"]
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_no_token_limit(mock_model, mock_memory):
|
||||
"""Test create_agent when no token limit is found."""
|
||||
mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "ciayn_agent"
|
||||
mock_ciayn.assert_called_once_with(
|
||||
mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_missing_config(mock_model, mock_memory):
|
||||
"""Test create_agent with missing configuration."""
|
||||
mock_memory.get.return_value = {"provider": "openai"}
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "ciayn_agent"
|
||||
mock_ciayn.assert_called_once_with(
|
||||
mock_model,
|
||||
[],
|
||||
max_tokens=DEFAULT_TOKEN_LIMIT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messages():
|
||||
"""Fixture providing mock message objects."""
|
||||
|
||||
return [
|
||||
SystemMessage(content="System prompt"),
|
||||
HumanMessage(content="Human message 1"),
|
||||
AIMessage(content="AI response 1"),
|
||||
HumanMessage(content="Human message 2"),
|
||||
AIMessage(content="AI response 2"),
|
||||
]
|
||||
|
||||
|
||||
def test_state_modifier(mock_messages):
|
||||
"""Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
|
||||
state = AgentState(messages=mock_messages)
|
||||
|
||||
with patch(
|
||||
"ra_aid.agents.ciayn_agent.CiaynAgent._estimate_tokens"
|
||||
) as mock_estimate:
|
||||
mock_estimate.side_effect = lambda msg: 100 if msg else 0
|
||||
|
||||
result = state_modifier(state, max_tokens=250)
|
||||
|
||||
assert len(result) < len(mock_messages)
|
||||
assert isinstance(result[0], SystemMessage)
|
||||
assert result[-1] == mock_messages[-1]
|
||||
|
||||
|
||||
def test_create_agent_with_checkpointer(mock_model, mock_memory):
|
||||
"""Test create_agent with checkpointer argument."""
|
||||
mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
|
||||
mock_checkpointer = Mock()
|
||||
|
||||
with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
|
||||
mock_ciayn.return_value = "ciayn_agent"
|
||||
agent = create_agent(mock_model, [], checkpointer=mock_checkpointer)
|
||||
|
||||
assert agent == "ciayn_agent"
|
||||
mock_ciayn.assert_called_once_with(
|
||||
mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"]
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
|
||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||
mock_memory.get.return_value = {
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": True,
|
||||
}
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
mock_limit.return_value = 100000
|
||||
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "react_agent"
|
||||
args = mock_react.call_args
|
||||
assert "state_modifier" in args[1]
|
||||
assert callable(args[1]["state_modifier"])
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
|
||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||
mock_memory.get.return_value = {
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": False,
|
||||
}
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
mock_limit.return_value = 100000
|
||||
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "react_agent"
|
||||
mock_react.assert_called_once_with(mock_model, [])
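
test_state_modifier above only constrains the observable behaviour: the leading SystemMessage is always preserved, and newer messages are kept until the estimated token budget runs out. A minimal sketch with that shape, assuming CiaynAgent._estimate_tokens is the per-message estimator (as the patch target in the test suggests), could be:

from langchain_core.messages import BaseMessage

from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT


def state_modifier_sketch(state, max_tokens: int = DEFAULT_TOKEN_LIMIT) -> list[BaseMessage]:
    """Keep the first (system) message plus as many of the newest messages as fit."""
    messages = state["messages"] if isinstance(state, dict) else list(state)
    if not messages:
        return []

    first, rest = messages[0], messages[1:]
    budget = max_tokens - CiaynAgent._estimate_tokens(first)

    kept: list[BaseMessage] = []
    # Walk newest-to-oldest so the most recent exchange survives trimming.
    for msg in reversed(rest):
        cost = CiaynAgent._estimate_tokens(msg)
        if cost > budget:
            break
        kept.append(msg)
        budget -= cost

    return [first] + list(reversed(kept))

With the test's stubbed estimate of 100 tokens per message and max_tokens=250, this keeps the SystemMessage and only the final AIMessage, which satisfies all three assertions; binding max_tokens via functools.partial over such a function is one way to produce the callable state_modifier that the token-limiting test checks for.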