Set Default Max Token Limit with Provider/Model Dictionary and Limit Tokens for Anthropic Claude React Agent (#45)

* feat(agent_utils.py): add get_model_token_limit function to retrieve token limits for models based on provider and model name
feat(models_tokens.py): create models_tokens module to store token limits for various models and providers
test(agent_utils.py): implement unit tests for get_model_token_limit and create_agent functions to ensure correct behavior and error handling

* test: Add unit tests for token limiting and agent creation functionality

* fix: Correct indentation and add missing test function for error handling

* fix: Update test assertion to use messages_modifier instead of state_modifier

* feat(agent_utils.py): add limit_tokens function to manage message token limits and preserve system message
fix(agent_utils.py): update get_model_token_limit to handle exceptions and return None on error
fix(ciayn_agent.py): set default max_tokens to DEFAULT_TOKEN_LIMIT in CiaynAgent initialization
feat(models_tokens.py): define DEFAULT_TOKEN_LIMIT for consistent token limit management across agents
style(agent_utils.py): improve code formatting and consistency in function definitions and comments
style(agent_utils.py): refactor imports for better organization and readability
fix(test_agent_utils.py): correct test assertion to use state_modifier instead of messages_modifier for create_agent function

* refactor(agent_utils.py): improve docstring clarity and formatting for limit_tokens function to enhance readability
refactor(test_agent_utils.py): format test assertions for consistency and readability in agent creation tests

* feat: Update limit_tokens function to support Dict type for state parameter

* feat: Update limit_tokens to handle both list and dict input types

* refactor: Extract duplicate token trimming logic into helper function

* refactor: Rename and update message trimming functions for clarity

* refactor: Extract agent kwargs logic into a helper method for reuse

* refactor: Rename _build_agent_kwargs to build_agent_kwargs for clarity

* fix: Ensure state_modifier is passed correctly for agent creation

* test: Add tests for create_agent token limiting configuration

* refactor: Simplify CiaynAgent instantiation to only use max_tokens parameter

* refactor: Remove is_react_agent parameter from build_agent_kwargs function

* test: Fix test assertions for state_modifier in agent creation tests

* fix: Update agent creation to handle checkpointer and simplify tests

* test: Remove unnecessary assertions from agent creation test

* feat: Implement token limiting configuration for create_agent function

* refactor: Remove unused model info and token limit retrieval code

* test: Fix assertion errors in agent creation tests and update state_modifier handling

* test: Remove commented-out code and clarify assertions in tests

* test: Fix assertion in test_create_agent_anthropic_token_limiting_disabled

* feat(main.py): add --limit-tokens argument to control token limiting in agent state
fix(main.py): include limit_tokens in configuration to ensure proper state management

* test: Refactor agent creation tests for improved readability and consistency

* test: Modify error handling in create_agent test to use side_effect on get_model_token_limit

* test: Improve error handling in create_agent test to verify fallback behavior

* test: Trigger exception on get_model_token_limit in error handling test

* refactor(agent_utils.py): remove unused config parameter from create_agent function to simplify the function signature
fix(agent_utils.py): ensure config is always retrieved from _global_memory with a default value to prevent potential errors
test(tests/test_agent_utils.py): remove outdated test for create_agent error handling to clean up the test suite

* feat: Add debug print for agent_kwargs in create_agent function

* refactor: Replace lambda with inner function for state_modifier in agent_utils

* refactor: Simplify limit_tokens function to return only message sequences

* feat: Add debug print statements to show token trimming details in trim_messages

* PAIN

* feat(main.py): add debug print statement for args.chat to assist in troubleshooting chat mode
feat(agent_utils.py): implement estimate_messages_tokens function to calculate total tokens in messages
refactor(agent_utils.py): replace token counting logic in trim_messages_with_removal with estimate_messages_tokens for clarity
refactor(agent_utils.py): modify state_modifier to accept model and max_tokens parameters for better flexibility
refactor(agent_utils.py): update build_agent_kwargs to pass model to state_modifier for improved functionality

* feat: Add .direnvrc to manage Python virtual environment activation

* refactor: Update state_modifier to handle first message token count and trim messages

* chore: remove unused .direnvrc file to clean up project structure
feat: add .envrc to .gitignore to prevent environment configuration file from being tracked
fix: update help text for --disable-limit-tokens argument for clarity
refactor: clean up imports in agent_utils.py for better readability
refactor: remove unused functions and comments in agent_utils.py to streamline code
test: add unit tests for state_modifier function to ensure correct message trimming behavior

* refactor: Remove commented-out code in create_agent function

* feat: Add is_anthropic_claude method to check provider and model name

* fix: Correct search/replace block to match existing lines in agent_utils.py

* fix(main.py): update help text for --disable-limit-tokens argument to clarify it applies to react agents
refactor(agent_utils.py): streamline token limit retrieval and improve readability by removing redundant checks and restructuring code
refactor(agent_utils.py): modify build_agent_kwargs to use is_anthropic_claude for clarity and maintainability

* test: Update tests to pass config argument to get_model_token_limit()

* refactor(agent_utils.py): remove unnecessary print statements and improve function signatures for clarity
refactor(agent_utils.py): consolidate provider and model_name retrieval into config parameter for better maintainability

* test: Remove redundant token limiting tests from agent creation logic

* test: Refactor test_state_modifier to use mock_messages fixture

* test(tests): update test description for clarity on state_modifier behavior and use mock_messages for assertions to ensure consistency

* refactor(agent_utils.py): simplify token limit retrieval by removing unnecessary variable initialization and defaulting to None in get method

* chore(main.py): remove debug print statement for args.chat to clean up output
Ariel Frischer 2025-01-20 11:41:29 -08:00 committed by GitHub
parent 0c39166172
commit 32fcf914ed
8 changed files with 905 additions and 203 deletions
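
At a high level, the commit wires token limiting into agent creation as follows (a condensed sketch of the flow implemented in the diffs below, using the names introduced by this commit; illustrative only):

# Provider/model lookup with a default, as done by get_model_token_limit()
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens

config = {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022", "limit_tokens": True}
token_limit = models_tokens.get(config["provider"], {}).get(config["model"]) or DEFAULT_TOKEN_LIMIT  # 200000

# create_agent() then branches: Anthropic Claude models get a langgraph react agent
# whose state_modifier trims older messages to stay under token_limit (unless
# --disable-limit-tokens was passed); every other provider/model gets
# CiaynAgent(model, tools, max_tokens=token_limit).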

.gitignore (+1)

@@ -11,3 +11,4 @@ __pycache__/
 /venv
 /.idea
 /htmlcov
+.envrc

main.py

@ -6,112 +6,120 @@ from rich.panel import Panel
from rich.console import Console from rich.console import Console
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from ra_aid.env import validate_environment from ra_aid.env import validate_environment
from ra_aid.project_info import get_project_info, format_project_info, display_project_status from ra_aid.project_info import (
get_project_info,
format_project_info,
display_project_status,
)
from ra_aid.tools.memory import _global_memory from ra_aid.tools.memory import _global_memory
from ra_aid.tools.human import ask_human from ra_aid.tools.human import ask_human
from ra_aid import print_stage_header, print_error from ra_aid import print_stage_header, print_error
from ra_aid.tools.human import ask_human
from ra_aid.__version__ import __version__ from ra_aid.__version__ import __version__
from ra_aid.agent_utils import ( from ra_aid.agent_utils import (
AgentInterrupt, AgentInterrupt,
run_agent_with_retry, run_agent_with_retry,
run_research_agent, run_research_agent,
run_planning_agent, run_planning_agent,
create_agent create_agent,
)
from ra_aid.prompts import (
CHAT_PROMPT,
WEB_RESEARCH_PROMPT_SECTION_CHAT
) )
from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
from ra_aid.llm import initialize_llm from ra_aid.llm import initialize_llm
from ra_aid.logging_config import setup_logging, get_logger from ra_aid.logging_config import setup_logging, get_logger
from ra_aid.tool_configs import ( from ra_aid.tool_configs import get_chat_tools
get_chat_tools
)
from ra_aid.dependencies import check_dependencies from ra_aid.dependencies import check_dependencies
import os import os
logger = get_logger(__name__) logger = get_logger(__name__)
def parse_arguments(args=None): def parse_arguments(args=None):
VALID_PROVIDERS = ['anthropic', 'openai', 'openrouter', 'openai-compatible', 'gemini'] VALID_PROVIDERS = [
ANTHROPIC_DEFAULT_MODEL = 'claude-3-5-sonnet-20241022' "anthropic",
OPENAI_DEFAULT_MODEL = 'gpt-4o' "openai",
"openrouter",
"openai-compatible",
"gemini",
]
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
OPENAI_DEFAULT_MODEL = "gpt-4o"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='RA.Aid - AI Agent for executing programming and research tasks', description="RA.Aid - AI Agent for executing programming and research tasks",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=''' epilog="""
Examples: Examples:
ra-aid -m "Add error handling to the database module" ra-aid -m "Add error handling to the database module"
ra-aid -m "Explain the authentication flow" --research-only ra-aid -m "Explain the authentication flow" --research-only
''' """,
) )
parser.add_argument( parser.add_argument(
'-m', '--message', "-m",
"--message",
type=str, type=str,
help='The task or query to be executed by the agent' help="The task or query to be executed by the agent",
) )
parser.add_argument( parser.add_argument(
'--version', "--version",
action='version', action="version",
version=f'%(prog)s {__version__}', version=f"%(prog)s {__version__}",
help='Show program version number and exit' help="Show program version number and exit",
) )
parser.add_argument( parser.add_argument(
'--research-only', "--research-only",
action='store_true', action="store_true",
help='Only perform research without implementation' help="Only perform research without implementation",
) )
parser.add_argument( parser.add_argument(
'--provider', "--provider",
type=str, type=str,
default='openai' if (os.getenv('OPENAI_API_KEY') and not os.getenv('ANTHROPIC_API_KEY')) else 'anthropic', default="openai"
if (os.getenv("OPENAI_API_KEY") and not os.getenv("ANTHROPIC_API_KEY"))
else "anthropic",
choices=VALID_PROVIDERS, choices=VALID_PROVIDERS,
help='The LLM provider to use' help="The LLM provider to use",
)
parser.add_argument("--model", type=str, help="The model name to use")
parser.add_argument(
"--cowboy-mode",
action="store_true",
help="Skip interactive approval for shell commands",
) )
parser.add_argument( parser.add_argument(
'--model', "--expert-provider",
type=str, type=str,
help='The model name to use' default="openai",
)
parser.add_argument(
'--cowboy-mode',
action='store_true',
help='Skip interactive approval for shell commands'
)
parser.add_argument(
'--expert-provider',
type=str,
default='openai',
choices=VALID_PROVIDERS, choices=VALID_PROVIDERS,
help='The LLM provider to use for expert knowledge queries (default: openai)' help="The LLM provider to use for expert knowledge queries (default: openai)",
) )
parser.add_argument( parser.add_argument(
'--expert-model', "--expert-model",
type=str, type=str,
help='The model name to use for expert knowledge queries (required for non-OpenAI providers)' help="The model name to use for expert knowledge queries (required for non-OpenAI providers)",
) )
parser.add_argument( parser.add_argument(
'--hil', '-H', "--hil",
action='store_true', "-H",
help='Enable human-in-the-loop mode, where the agent can prompt the user for additional information.' action="store_true",
help="Enable human-in-the-loop mode, where the agent can prompt the user for additional information.",
) )
parser.add_argument( parser.add_argument(
'--chat', "--chat",
action='store_true', action="store_true",
help='Enable chat mode with direct human interaction (implies --hil)' help="Enable chat mode with direct human interaction (implies --hil)",
) )
parser.add_argument( parser.add_argument(
'--verbose', "--verbose", action="store_true", help="Enable verbose logging output"
action='store_true',
help='Enable verbose logging output'
) )
parser.add_argument( parser.add_argument(
'--temperature', "--temperature",
type=float, type=float,
help='LLM temperature (0.0-2.0). Controls randomness in responses', help="LLM temperature (0.0-2.0). Controls randomness in responses",
default=None default=None,
    )
    parser.add_argument(
        "--disable-limit-tokens",
        action="store_false",
        help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
    )
if args is None: if args is None:
@ -129,23 +137,34 @@ Examples:
if parsed_args.provider == "openai": if parsed_args.provider == "openai":
parsed_args.model = parsed_args.model or OPENAI_DEFAULT_MODEL parsed_args.model = parsed_args.model or OPENAI_DEFAULT_MODEL
if parsed_args.provider == 'anthropic': if parsed_args.provider == "anthropic":
# Always use default model for Anthropic # Always use default model for Anthropic
parsed_args.model = ANTHROPIC_DEFAULT_MODEL parsed_args.model = ANTHROPIC_DEFAULT_MODEL
elif not parsed_args.model and not parsed_args.research_only: elif not parsed_args.model and not parsed_args.research_only:
# Require model for other providers unless in research mode # Require model for other providers unless in research mode
parser.error(f"--model is required when using provider '{parsed_args.provider}'") parser.error(
f"--model is required when using provider '{parsed_args.provider}'"
)
# Validate expert model requirement # Validate expert model requirement
if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only: if (
parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'") parsed_args.expert_provider != "openai"
and not parsed_args.expert_model
and not parsed_args.research_only
):
parser.error(
f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'"
)
# Validate temperature range if provided # Validate temperature range if provided
if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0): if parsed_args.temperature is not None and not (
parser.error('Temperature must be between 0.0 and 2.0') 0.0 <= parsed_args.temperature <= 2.0
):
parser.error("Temperature must be between 0.0 and 2.0")
return parsed_args return parsed_args
# Create console instance # Create console instance
console = Console() console = Console()
@ -157,14 +176,18 @@ implementation_memory = MemorySaver()
def is_informational_query() -> bool: def is_informational_query() -> bool:
"""Determine if the current query is informational based on implementation_requested state.""" """Determine if the current query is informational based on implementation_requested state."""
return _global_memory.get('config', {}).get('research_only', False) or not is_stage_requested('implementation') return _global_memory.get("config", {}).get(
"research_only", False
) or not is_stage_requested("implementation")
def is_stage_requested(stage: str) -> bool: def is_stage_requested(stage: str) -> bool:
"""Check if a stage has been requested to proceed.""" """Check if a stage has been requested to proceed."""
if stage == 'implementation': if stage == "implementation":
return _global_memory.get('implementation_requested', False) return _global_memory.get("implementation_requested", False)
return False return False
def main(): def main():
"""Main entry point for the ra-aid command line tool.""" """Main entry point for the ra-aid command line tool."""
args = parse_arguments() args = parse_arguments()
@ -175,26 +198,32 @@ def main():
# Check dependencies before proceeding # Check dependencies before proceeding
check_dependencies() check_dependencies()
expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) # Will exit if main env vars missing expert_enabled, expert_missing, web_research_enabled, web_research_missing = (
validate_environment(args)
) # Will exit if main env vars missing
logger.debug("Environment validation successful") logger.debug("Environment validation successful")
if expert_missing: if expert_missing:
console.print(Panel( console.print(
f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n" + Panel(
"\n".join(f"- {m}" for m in expert_missing) + f"[yellow]Expert tools disabled due to missing configuration:[/yellow]\n"
"\nSet the required environment variables or args to enable expert mode.", + "\n".join(f"- {m}" for m in expert_missing)
title="Expert Tools Disabled", + "\nSet the required environment variables or args to enable expert mode.",
style="yellow" title="Expert Tools Disabled",
)) style="yellow",
)
)
if web_research_missing: if web_research_missing:
console.print(Panel( console.print(
f"[yellow]Web research disabled due to missing configuration:[/yellow]\n" + Panel(
"\n".join(f"- {m}" for m in web_research_missing) + f"[yellow]Web research disabled due to missing configuration:[/yellow]\n"
"\nSet the required environment variables to enable web research.", + "\n".join(f"- {m}" for m in web_research_missing)
title="Web Research Disabled", + "\nSet the required environment variables to enable web research.",
style="yellow" title="Web Research Disabled",
)) style="yellow",
)
)
# Create the base model after validation # Create the base model after validation
model = initialize_llm(args.provider, args.model, temperature=args.temperature) model = initialize_llm(args.provider, args.model, temperature=args.temperature)
@ -216,7 +245,9 @@ def main():
formatted_project_info = "" formatted_project_info = ""
# Get initial request from user # Get initial request from user
initial_request = ask_human.invoke({"question": "What would you like help with?"}) initial_request = ask_human.invoke(
{"question": "What would you like help with?"}
)
# Get working directory and current date # Get working directory and current date
working_directory = os.getcwd() working_directory = os.getcwd()
@ -230,31 +261,41 @@ def main():
"cowboy_mode": args.cowboy_mode, "cowboy_mode": args.cowboy_mode,
"hil": True, # Always true in chat mode "hil": True, # Always true in chat mode
"web_research_enabled": web_research_enabled, "web_research_enabled": web_research_enabled,
"initial_request": initial_request "initial_request": initial_request,
"limit_tokens": args.disable_limit_tokens,
} }
# Store config in global memory # Store config in global memory
_global_memory['config'] = config _global_memory["config"] = config
_global_memory['config']['provider'] = args.provider _global_memory["config"]["provider"] = args.provider
_global_memory['config']['model'] = args.model _global_memory["config"]["model"] = args.model
_global_memory['config']['expert_provider'] = args.expert_provider _global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory['config']['expert_model'] = args.expert_model _global_memory["config"]["expert_model"] = args.expert_model
# Create chat agent with appropriate tools # Create chat agent with appropriate tools
chat_agent = create_agent( chat_agent = create_agent(
model, model,
get_chat_tools(expert_enabled=expert_enabled, web_research_enabled=web_research_enabled), get_chat_tools(
checkpointer=MemorySaver() expert_enabled=expert_enabled,
web_research_enabled=web_research_enabled,
),
checkpointer=MemorySaver(),
) )
# Run chat agent and exit # Run chat agent and exit
run_agent_with_retry(chat_agent, CHAT_PROMPT.format( run_agent_with_retry(
chat_agent,
CHAT_PROMPT.format(
initial_request=initial_request, initial_request=initial_request,
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if web_research_enabled else "", web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
if web_research_enabled
else "",
working_directory=working_directory, working_directory=working_directory,
current_date=current_date, current_date=current_date,
project_info=formatted_project_info project_info=formatted_project_info,
), config) ),
config,
)
return return
# Validate message is provided # Validate message is provided
@ -268,19 +309,20 @@ def main():
"recursion_limit": 100, "recursion_limit": 100,
"research_only": args.research_only, "research_only": args.research_only,
"cowboy_mode": args.cowboy_mode, "cowboy_mode": args.cowboy_mode,
"web_research_enabled": web_research_enabled "web_research_enabled": web_research_enabled,
"limit_tokens": args.disable_limit_tokens,
} }
# Store config in global memory for access by is_informational_query # Store config in global memory for access by is_informational_query
_global_memory['config'] = config _global_memory["config"] = config
# Store model configuration # Store model configuration
_global_memory['config']['provider'] = args.provider _global_memory["config"]["provider"] = args.provider
_global_memory['config']['model'] = args.model _global_memory["config"]["model"] = args.model
# Store expert provider and model in config # Store expert provider and model in config
_global_memory['config']['expert_provider'] = args.expert_provider _global_memory["config"]["expert_provider"] = args.expert_provider
_global_memory['config']['expert_model'] = args.expert_model _global_memory["config"]["expert_model"] = args.expert_model
# Run research stage # Run research stage
print_stage_header("Research Stage") print_stage_header("Research Stage")
@ -292,7 +334,7 @@ def main():
research_only=args.research_only, research_only=args.research_only,
hil=args.hil, hil=args.hil,
memory=research_memory, memory=research_memory,
config=config config=config,
) )
# Proceed with planning and implementation if not an informational query # Proceed with planning and implementation if not an informational query
@ -304,7 +346,7 @@ def main():
expert_enabled=expert_enabled, expert_enabled=expert_enabled,
hil=args.hil, hil=args.hil,
memory=planning_memory, memory=planning_memory,
config=config config=config,
) )
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
@ -313,5 +355,6 @@ def main():
print() print()
sys.exit(0) sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
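
Note the flag semantics: because --disable-limit-tokens is declared with action="store_false", args.disable_limit_tokens defaults to True and only becomes False when the flag is passed, which is why it can be stored directly as config["limit_tokens"]. A minimal standalone sketch of that behavior (not the project's parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-limit-tokens",
    action="store_false",
    help="Disable token limiting for Anthropic Claude react agents.",
)

assert parser.parse_args([]).disable_limit_tokens is True       # default: limiting enabled
assert parser.parse_args(["--disable-limit-tokens"]).disable_limit_tokens is False

config = {"limit_tokens": parser.parse_args([]).disable_limit_tokens}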

ra_aid/agent_utils.py

@ -3,22 +3,27 @@
import sys import sys
import time import time
import uuid import uuid
from typing import Optional, Any from typing import Optional, Any, List, Dict, Sequence
from langchain_core.messages import BaseMessage, trim_messages
import signal import signal
import threading
import time
from typing import Optional
from ra_aid.project_info import get_project_info, format_project_info, display_project_status from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt.chat_agent_executor import AgentState
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens
from ra_aid.agents.ciayn_agent import CiaynAgent
import threading
from ra_aid.project_info import (
get_project_info,
format_project_info,
display_project_status,
)
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.project_info import get_project_info, format_project_info, display_project_status
from ra_aid.console.formatting import print_stage_header, print_error from ra_aid.console.formatting import print_stage_header, print_error
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.tools import tool from langchain_core.tools import tool
from typing import List, Any
from ra_aid.console.output import print_agent_output from ra_aid.console.output import print_agent_output
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
from ra_aid.exceptions import AgentInterrupt from ra_aid.exceptions import AgentInterrupt
@ -26,7 +31,7 @@ from ra_aid.tool_configs import (
get_implementation_tools, get_implementation_tools,
get_research_tools, get_research_tools,
get_planning_tools, get_planning_tools,
get_web_research_tools get_web_research_tools,
) )
from ra_aid.prompts import ( from ra_aid.prompts import (
IMPLEMENTATION_PROMPT, IMPLEMENTATION_PROMPT,
@ -41,13 +46,9 @@ from ra_aid.prompts import (
HUMAN_PROMPT_SECTION_RESEARCH, HUMAN_PROMPT_SECTION_RESEARCH,
PLANNING_PROMPT, PLANNING_PROMPT,
EXPERT_PROMPT_SECTION_PLANNING, EXPERT_PROMPT_SECTION_PLANNING,
WEB_RESEARCH_PROMPT_SECTION_PLANNING,
HUMAN_PROMPT_SECTION_PLANNING, HUMAN_PROMPT_SECTION_PLANNING,
WEB_RESEARCH_PROMPT, WEB_RESEARCH_PROMPT,
EXPERT_PROMPT_SECTION_CHAT,
CHAT_PROMPT,
) )
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError
@ -60,58 +61,189 @@ from ra_aid.tools.memory import (
get_memory_value, get_memory_value,
get_related_files, get_related_files,
) )
from ra_aid.tool_configs import get_research_tools
from ra_aid.prompts import (
RESEARCH_PROMPT,
RESEARCH_ONLY_PROMPT,
EXPERT_PROMPT_SECTION_RESEARCH,
HUMAN_PROMPT_SECTION_RESEARCH
)
console = Console() console = Console()
logger = get_logger(__name__) logger = get_logger(__name__)
@tool @tool
def output_markdown_message(message: str) -> str: def output_markdown_message(message: str) -> str:
"""Outputs a message to the user, optionally prompting for input.""" """Outputs a message to the user, optionally prompting for input."""
console.print(Panel(Markdown(message.strip()), title="🤖 Assistant")) console.print(Panel(Markdown(message.strip()), title="🤖 Assistant"))
return "Message output." return "Message output."

def estimate_messages_tokens(messages: Sequence[BaseMessage]) -> int:
    """Helper function to estimate total tokens in a sequence of messages.

    Args:
        messages: Sequence of messages to count tokens for

    Returns:
        Total estimated token count
    """
    if not messages:
        return 0

    estimate_tokens = CiaynAgent._estimate_tokens
    return sum(estimate_tokens(msg) for msg in messages)

def state_modifier(
    state: AgentState, max_tokens: int = DEFAULT_TOKEN_LIMIT
) -> list[BaseMessage]:
    """Given the agent state and max_tokens, return a trimmed list of messages.

    Args:
        state: The current agent state containing messages
        max_tokens: Maximum number of tokens to allow (default: DEFAULT_TOKEN_LIMIT)

    Returns:
        list[BaseMessage]: Trimmed list of messages that fits within token limit
    """
    messages = state["messages"]

    if not messages:
        return []

    first_message = messages[0]
    remaining_messages = messages[1:]

    first_tokens = estimate_messages_tokens([first_message])
    new_max_tokens = max_tokens - first_tokens

    trimmed_remaining = trim_messages(
        remaining_messages,
        token_counter=estimate_messages_tokens,
        max_tokens=new_max_tokens,
        strategy="last",
        allow_partial=False,
    )

    return [first_message] + trimmed_remaining

def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
    """Get the token limit for the current model configuration.

    Returns:
        Optional[int]: The token limit if found, None otherwise
    """
    try:
        provider = config.get("provider", "")
        model_name = config.get("model", "")

        provider_tokens = models_tokens.get(provider, {})
        token_limit = provider_tokens.get(model_name, None)
        if token_limit:
            logger.debug(
                f"Found token limit for {provider}/{model_name}: {token_limit}"
            )
        else:
            logger.debug(f"Could not find token limit for {provider}/{model_name}")

        return token_limit
    except Exception as e:
        logger.warning(f"Failed to get model token limit: {e}")
        return None

def build_agent_kwargs(
    checkpointer: Optional[Any] = None,
    config: Dict[str, Any] = None,
    token_limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Build kwargs dictionary for agent creation.

    Args:
        checkpointer: Optional memory checkpointer
        config: Optional configuration dictionary
        token_limit: Optional token limit for the model

    Returns:
        Dictionary of kwargs for agent creation
    """
    agent_kwargs = {}

    if checkpointer is not None:
        agent_kwargs["checkpointer"] = checkpointer

    if config.get("limit_tokens", True) and is_anthropic_claude(config):

        def wrapped_state_modifier(state: AgentState) -> list[BaseMessage]:
            return state_modifier(state, max_tokens=token_limit)

        agent_kwargs["state_modifier"] = wrapped_state_modifier

    return agent_kwargs

def is_anthropic_claude(config: Dict[str, Any]) -> bool:
    """Check if the provider and model name indicate an Anthropic Claude model.

    Args:
        config: Configuration dictionary containing the provider and model name

    Returns:
        bool: True if this is an Anthropic Claude model
    """
    provider = config.get("provider", "")
    model_name = config.get("model", "")
    return (
        provider.lower() == "anthropic"
        and model_name
        and "claude" in model_name.lower()
    )

 def create_agent(
     model: BaseChatModel,
     tools: List[Any],
     *,
-    checkpointer: Any = None
+    checkpointer: Any = None,
 ) -> Any:
     """Create a react agent with the given configuration.

     Args:
         model: The LLM model to use
         tools: List of tools to provide to the agent
         checkpointer: Optional memory checkpointer
+        config: Optional configuration dictionary containing settings like:
+            - limit_tokens (bool): Whether to apply token limiting (default: True)
+            - provider (str): The LLM provider name
+            - model (str): The model name

     Returns:
         The created agent instance
+
+    Token limiting helps prevent context window overflow by trimming older messages
+    while preserving system messages. It can be disabled by setting
+    config['limit_tokens'] = False.
     """
     try:
-        # Get model name if available
-        provider = _global_memory.get('config', {}).get('provider')
-        model_name = _global_memory.get('config', {}).get('model')
+        config = _global_memory.get("config", {})
+        token_limit = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT

         # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
-        if provider == 'anthropic' and 'claude' in model_name:
+        if is_anthropic_claude(config):
             logger.debug("Using create_react_agent to instantiate agent.")
-            return create_react_agent(model, tools, checkpointer=checkpointer)
+            agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+            return create_react_agent(model, tools, **agent_kwargs)
         else:
-            logger.debug("Using CiaynAgent agent instance.")
-            return CiaynAgent(model, tools)
+            logger.debug("Using CiaynAgent agent instance")
+            return CiaynAgent(model, tools, max_tokens=token_limit)
     except Exception as e:
         # Default to REACT agent if provider/model detection fails
         logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
-        return create_react_agent(model, tools, checkpointer=checkpointer)
+        config = _global_memory.get("config", {})
+        token_limit = get_model_token_limit(config)
+        agent_kwargs = build_agent_kwargs(checkpointer, config, token_limit)
+        return create_react_agent(model, tools, **agent_kwargs)

def run_research_agent( def run_research_agent(
base_task_or_query: str, base_task_or_query: str,
@ -124,7 +256,7 @@ def run_research_agent(
memory: Optional[Any] = None, memory: Optional[Any] = None,
config: Optional[dict] = None, config: Optional[dict] = None,
thread_id: Optional[str] = None, thread_id: Optional[str] = None,
console_message: Optional[str] = None console_message: Optional[str] = None,
) -> Optional[str]: ) -> Optional[str]:
"""Run a research agent with the given configuration. """Run a research agent with the given configuration.
@ -153,8 +285,13 @@ def run_research_agent(
""" """
thread_id = thread_id or str(uuid.uuid4()) thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting research agent with thread_id=%s", thread_id) logger.debug("Starting research agent with thread_id=%s", thread_id)
logger.debug("Research configuration: expert=%s, research_only=%s, hil=%s, web=%s", logger.debug(
expert_enabled, research_only, hil, web_research_enabled) "Research configuration: expert=%s, research_only=%s, hil=%s, web=%s",
expert_enabled,
research_only,
hil,
web_research_enabled,
)
# Initialize memory if not provided # Initialize memory if not provided
if memory is None: if memory is None:
@ -169,7 +306,7 @@ def run_research_agent(
research_only=research_only, research_only=research_only,
expert_enabled=expert_enabled, expert_enabled=expert_enabled,
human_interaction=hil, human_interaction=hil,
web_research_enabled=config.get('web_research_enabled', False) web_research_enabled=config.get("web_research_enabled", False),
) )
# Create agent # Create agent
@ -178,7 +315,11 @@ def run_research_agent(
# Format prompt sections # Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
web_research_section = WEB_RESEARCH_PROMPT_SECTION_RESEARCH if config.get('web_research_enabled') else "" web_research_section = (
WEB_RESEARCH_PROMPT_SECTION_RESEARCH
if config.get("web_research_enabled")
else ""
)
# Get research context from memory # Get research context from memory
key_facts = _global_memory.get("key_facts", "") key_facts = _global_memory.get("key_facts", "")
@ -196,29 +337,30 @@ def run_research_agent(
# Build prompt # Build prompt
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
base_task=base_task_or_query, base_task=base_task_or_query,
research_only_note='' if research_only else ' Only request implementation if the user explicitly asked for changes to be made.', research_only_note=""
if research_only
else " Only request implementation if the user explicitly asked for changes to be made.",
expert_section=expert_section, expert_section=expert_section,
human_section=human_section, human_section=human_section,
web_research_section=web_research_section, web_research_section=web_research_section,
key_facts=key_facts, key_facts=key_facts,
work_log=get_memory_value('work_log'), work_log=get_memory_value("work_log"),
code_snippets=code_snippets, code_snippets=code_snippets,
related_files=related_files, related_files=related_files,
project_info=formatted_project_info project_info=formatted_project_info,
) )
# Set up configuration # Set up configuration
run_config = { run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
"configurable": {"thread_id": thread_id},
"recursion_limit": 100
}
if config: if config:
run_config.update(config) run_config.update(config)
try: try:
# Display console message if provided # Display console message if provided
if console_message: if console_message:
console.print(Panel(Markdown(console_message), title="🔬 Looking into it...")) console.print(
Panel(Markdown(console_message), title="🔬 Looking into it...")
)
if project_info: if project_info:
display_project_status(project_info) display_project_status(project_info)
@ -239,7 +381,7 @@ def run_research_agent(
memory=memory, memory=memory,
config=config, config=config,
thread_id=thread_id, thread_id=thread_id,
console_message=console_message console_message=console_message,
) )
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
raise raise
@ -247,6 +389,7 @@ def run_research_agent(
logger.error("Research agent failed: %s", str(e), exc_info=True) logger.error("Research agent failed: %s", str(e), exc_info=True)
raise raise
def run_web_research_agent( def run_web_research_agent(
query: str, query: str,
model, model,
@ -257,7 +400,7 @@ def run_web_research_agent(
memory: Optional[Any] = None, memory: Optional[Any] = None,
config: Optional[dict] = None, config: Optional[dict] = None,
thread_id: Optional[str] = None, thread_id: Optional[str] = None,
console_message: Optional[str] = None console_message: Optional[str] = None,
) -> Optional[str]: ) -> Optional[str]:
"""Run a web research agent with the given configuration. """Run a web research agent with the given configuration.
@ -284,8 +427,12 @@ def run_web_research_agent(
""" """
thread_id = thread_id or str(uuid.uuid4()) thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting web research agent with thread_id=%s", thread_id) logger.debug("Starting web research agent with thread_id=%s", thread_id)
logger.debug("Web research configuration: expert=%s, hil=%s, web=%s", logger.debug(
expert_enabled, hil, web_research_enabled) "Web research configuration: expert=%s, hil=%s, web=%s",
expert_enabled,
hil,
web_research_enabled,
)
# Initialize memory if not provided # Initialize memory if not provided
if memory is None: if memory is None:
@ -317,14 +464,11 @@ def run_web_research_agent(
human_section=human_section, human_section=human_section,
key_facts=key_facts, key_facts=key_facts,
code_snippets=code_snippets, code_snippets=code_snippets,
related_files=related_files related_files=related_files,
) )
# Set up configuration # Set up configuration
run_config = { run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
"configurable": {"thread_id": thread_id},
"recursion_limit": 100
}
if config: if config:
run_config.update(config) run_config.update(config)
@ -342,6 +486,7 @@ def run_web_research_agent(
logger.error("Web research agent failed: %s", str(e), exc_info=True) logger.error("Web research agent failed: %s", str(e), exc_info=True)
raise raise
def run_planning_agent( def run_planning_agent(
base_task: str, base_task: str,
model, model,
@ -350,7 +495,7 @@ def run_planning_agent(
hil: bool = False, hil: bool = False,
memory: Optional[Any] = None, memory: Optional[Any] = None,
config: Optional[dict] = None, config: Optional[dict] = None,
thread_id: Optional[str] = None thread_id: Optional[str] = None,
) -> Optional[str]: ) -> Optional[str]:
"""Run a planning agent to create implementation plans. """Run a planning agent to create implementation plans.
@ -379,7 +524,10 @@ def run_planning_agent(
thread_id = str(uuid.uuid4()) thread_id = str(uuid.uuid4())
# Configure tools # Configure tools
tools = get_planning_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) tools = get_planning_tools(
expert_enabled=expert_enabled,
web_research_enabled=config.get("web_research_enabled", False),
)
# Create agent # Create agent
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
@ -387,7 +535,11 @@ def run_planning_agent(
# Format prompt sections # Format prompt sections
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
web_research_section = WEB_RESEARCH_PROMPT_SECTION_PLANNING if config.get('web_research_enabled') else "" web_research_section = (
WEB_RESEARCH_PROMPT_SECTION_PLANNING
if config.get("web_research_enabled")
else ""
)
# Build prompt # Build prompt
planning_prompt = PLANNING_PROMPT.format( planning_prompt = PLANNING_PROMPT.format(
@ -395,19 +547,18 @@ def run_planning_agent(
human_section=human_section, human_section=human_section,
web_research_section=web_research_section, web_research_section=web_research_section,
base_task=base_task, base_task=base_task,
research_notes=get_memory_value('research_notes'), research_notes=get_memory_value("research_notes"),
related_files="\n".join(get_related_files()), related_files="\n".join(get_related_files()),
key_facts=get_memory_value('key_facts'), key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value('key_snippets'), key_snippets=get_memory_value("key_snippets"),
work_log=get_memory_value('work_log'), work_log=get_memory_value("work_log"),
research_only_note='' if config.get('research_only') else ' Only request implementation if the user explicitly asked for changes to be made.' research_only_note=""
if config.get("research_only")
else " Only request implementation if the user explicitly asked for changes to be made.",
) )
# Set up configuration # Set up configuration
run_config = { run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
"configurable": {"thread_id": thread_id},
"recursion_limit": 100
}
if config: if config:
run_config.update(config) run_config.update(config)
@ -421,6 +572,7 @@ def run_planning_agent(
logger.error("Planning agent failed: %s", str(e), exc_info=True) logger.error("Planning agent failed: %s", str(e), exc_info=True)
raise raise
def run_task_implementation_agent( def run_task_implementation_agent(
base_task: str, base_task: str,
tasks: list, tasks: list,
@ -433,7 +585,7 @@ def run_task_implementation_agent(
web_research_enabled: bool = False, web_research_enabled: bool = False,
memory: Optional[Any] = None, memory: Optional[Any] = None,
config: Optional[dict] = None, config: Optional[dict] = None,
thread_id: Optional[str] = None thread_id: Optional[str] = None,
) -> Optional[str]: ) -> Optional[str]:
"""Run an implementation agent for a specific task. """Run an implementation agent for a specific task.
@ -454,7 +606,11 @@ def run_task_implementation_agent(
""" """
thread_id = thread_id or str(uuid.uuid4()) thread_id = thread_id or str(uuid.uuid4())
logger.debug("Starting implementation agent with thread_id=%s", thread_id) logger.debug("Starting implementation agent with thread_id=%s", thread_id)
logger.debug("Implementation configuration: expert=%s, web=%s", expert_enabled, web_research_enabled) logger.debug(
"Implementation configuration: expert=%s, web=%s",
expert_enabled,
web_research_enabled,
)
logger.debug("Task details: base_task=%s, current_task=%s", base_task, task) logger.debug("Task details: base_task=%s, current_task=%s", base_task, task)
logger.debug("Related files: %s", related_files) logger.debug("Related files: %s", related_files)
@ -467,7 +623,10 @@ def run_task_implementation_agent(
thread_id = str(uuid.uuid4()) thread_id = str(uuid.uuid4())
# Configure tools # Configure tools
tools = get_implementation_tools(expert_enabled=expert_enabled, web_research_enabled=config.get('web_research_enabled', False)) tools = get_implementation_tools(
expert_enabled=expert_enabled,
web_research_enabled=config.get("web_research_enabled", False),
)
# Create agent # Create agent
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory)
@ -479,20 +638,21 @@ def run_task_implementation_agent(
tasks=tasks, tasks=tasks,
plan=plan, plan=plan,
related_files=related_files, related_files=related_files,
key_facts=get_memory_value('key_facts'), key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value('key_snippets'), key_snippets=get_memory_value("key_snippets"),
research_notes=get_memory_value('research_notes'), research_notes=get_memory_value("research_notes"),
work_log=get_memory_value('work_log'), work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION if _global_memory.get('config', {}).get('hil', False) else "", human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT if config.get('web_research_enabled') else "" if _global_memory.get("config", {}).get("hil", False)
else "",
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT
if config.get("web_research_enabled")
else "",
) )
# Set up configuration # Set up configuration
run_config = { run_config = {"configurable": {"thread_id": thread_id}, "recursion_limit": 100}
"configurable": {"thread_id": thread_id},
"recursion_limit": 100
}
if config: if config:
run_config.update(config) run_config.update(config)
@ -505,10 +665,12 @@ def run_task_implementation_agent(
logger.error("Implementation agent failed: %s", str(e), exc_info=True) logger.error("Implementation agent failed: %s", str(e), exc_info=True)
raise raise
_CONTEXT_STACK = [] _CONTEXT_STACK = []
_INTERRUPT_CONTEXT = None _INTERRUPT_CONTEXT = None
_FEEDBACK_MODE = False _FEEDBACK_MODE = False
def _request_interrupt(signum, frame): def _request_interrupt(signum, frame):
global _INTERRUPT_CONTEXT global _INTERRUPT_CONTEXT
if _CONTEXT_STACK: if _CONTEXT_STACK:
@ -520,6 +682,7 @@ def _request_interrupt(signum, frame):
print() print()
sys.exit(0) sys.exit(0)
class InterruptibleSection: class InterruptibleSection:
def __enter__(self): def __enter__(self):
_CONTEXT_STACK.append(self) _CONTEXT_STACK.append(self)
@ -528,10 +691,12 @@ class InterruptibleSection:
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
_CONTEXT_STACK.remove(self) _CONTEXT_STACK.remove(self)
def check_interrupt(): def check_interrupt():
if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]: if _CONTEXT_STACK and _INTERRUPT_CONTEXT is _CONTEXT_STACK[-1]:
raise AgentInterrupt("Interrupt requested") raise AgentInterrupt("Interrupt requested")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for API errors.""" """Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt)) logger.debug("Running agent with prompt length: %d", len(prompt))
@ -546,48 +711,68 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
with InterruptibleSection(): with InterruptibleSection():
try: try:
# Track agent execution depth # Track agent execution depth
current_depth = _global_memory.get('agent_depth', 0) current_depth = _global_memory.get("agent_depth", 0)
_global_memory['agent_depth'] = current_depth + 1 _global_memory["agent_depth"] = current_depth + 1
for attempt in range(max_retries): for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries) logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt() check_interrupt()
try: try:
for chunk in agent.stream({"messages": [HumanMessage(content=prompt)]}, config): for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]}, config
):
logger.debug("Agent output: %s", chunk) logger.debug("Agent output: %s", chunk)
check_interrupt() check_interrupt()
print_agent_output(chunk) print_agent_output(chunk)
if _global_memory['plan_completed']: if _global_memory["plan_completed"]:
_global_memory['plan_completed'] = False _global_memory["plan_completed"] = False
_global_memory['task_completed'] = False _global_memory["task_completed"] = False
_global_memory['completion_message'] = '' _global_memory["completion_message"] = ""
break break
if _global_memory['task_completed']: if _global_memory["task_completed"]:
_global_memory['task_completed'] = False _global_memory["task_completed"] = False
_global_memory['completion_message'] = '' _global_memory["completion_message"] = ""
break break
logger.debug("Agent run completed successfully") logger.debug("Agent run completed successfully")
return "Agent run completed successfully" return "Agent run completed successfully"
except (KeyboardInterrupt, AgentInterrupt): except (KeyboardInterrupt, AgentInterrupt):
raise raise
except (InternalServerError, APITimeoutError, RateLimitError, APIError, ValueError) as e: except (
InternalServerError,
APITimeoutError,
RateLimitError,
APIError,
ValueError,
) as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
error_str = str(e).lower() error_str = str(e).lower()
if 'code' not in error_str or '429' not in error_str: if "code" not in error_str or "429" not in error_str:
raise # Re-raise ValueError if it's not a Lambda 429 raise # Re-raise ValueError if it's not a Lambda 429
if attempt == max_retries - 1: if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e)) logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}") raise RuntimeError(
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e)) f"Max retries ({max_retries}) exceeded. Last error: {e}"
delay = base_delay * (2 ** attempt) )
print_error(f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})") logger.warning(
"API error (attempt %d/%d): %s",
attempt + 1,
max_retries,
str(e),
)
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic() start = time.monotonic()
while time.monotonic() - start < delay: while time.monotonic() - start < delay:
check_interrupt() check_interrupt()
time.sleep(0.1) time.sleep(0.1)
finally: finally:
# Reset depth tracking # Reset depth tracking
_global_memory['agent_depth'] = _global_memory.get('agent_depth', 1) - 1 _global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
if original_handler and threading.current_thread() is threading.main_thread(): if (
original_handler
and threading.current_thread() is threading.main_thread()
):
signal.signal(signal.SIGINT, original_handler) signal.signal(signal.SIGINT, original_handler)
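
For reference, the new state_modifier behaves roughly like this at runtime (an illustrative sketch assuming the ra_aid package and langchain-core are importable; message sizes are arbitrary):

from langchain_core.messages import HumanMessage, SystemMessage
from ra_aid.agent_utils import estimate_messages_tokens, state_modifier

messages = [SystemMessage(content="You are RA.Aid.")] + [
    HumanMessage(content=f"step {i} " * 200) for i in range(50)
]

# The first message is always preserved; the rest are trimmed from the oldest
# side until the estimated total fits under max_tokens.
trimmed = state_modifier({"messages": messages}, max_tokens=2000)

assert trimmed[0] is messages[0]
assert estimate_messages_tokens(trimmed) <= 2000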

ra_aid/agents/ciayn_agent.py

@@ -1,8 +1,8 @@
 import re
 from dataclasses import dataclass
 from typing import Dict, Any, Generator, List, Optional, Union
-from typing import Dict, Any, Generator, List, Optional, Union
+from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
 from ra_aid.tools.reflection import get_function_info
 from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage

@@ -66,7 +66,7 @@ class CiaynAgent:
     """

-    def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = 100000):
+    def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT):
         """Initialize the agent with a model and list of tools.

         Args:

@@ -263,6 +263,10 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
             text = content.content
         else:
             text = content

+        # create-react-agent tool calls can be lists
+        if isinstance(text, List):
+            return 0
+
         if not text:
             return 0

ra_aid/llm.py

@@ -104,4 +104,4 @@ def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> B
            model=model_name,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

ra_aid/models_tokens.py (new file, +266)

@@ -0,0 +1,266 @@
"""
Token limits for known models, keyed by provider.
"""
DEFAULT_TOKEN_LIMIT = 100000
models_tokens = {
"openai": {
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5": 4096,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-mini": 128000,
"o1-preview": 128000,
"o1-mini": 128000,
},
"azure_openai": {
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5": 4096,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-turbo": 128000,
"gpt-4-turbo-2024-04-09": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"chatgpt-4o-latest": 128000,
"o1-preview": 128000,
"o1-mini": 128000,
},
"google_genai": {
"gemini-pro": 128000,
"gemini-1.5-flash-latest": 128000,
"gemini-1.5-pro-latest": 128000,
"models/embedding-001": 2048,
},
"google_vertexai": {
"gemini-1.5-flash": 128000,
"gemini-1.5-pro": 128000,
"gemini-1.0-pro": 128000,
},
"ollama": {
"command-r": 12800,
"codellama": 16000,
"dbrx": 32768,
"deepseek-coder:33b": 16000,
"falcon": 2048,
"llama2": 4096,
"llama2:7b": 4096,
"llama2:13b": 4096,
"llama2:70b": 4096,
"llama3": 8192,
"llama3:8b": 8192,
"llama3:70b": 8192,
"llama3.1": 128000,
"llama3.1:8b": 128000,
"llama3.1:70b": 128000,
"lama3.1:405b": 128000,
"llama3.2": 128000,
"llama3.2:1b": 128000,
"llama3.2:3b": 128000,
"llama3.3:70b": 128000,
"scrapegraph": 8192,
"mistral-small": 128000,
"mistral-openorca": 32000,
"mistral-large": 128000,
"grok-1": 8192,
"llava": 4096,
"mixtral:8x22b-instruct": 65536,
"nomic-embed-text": 8192,
"nous-hermes2:34b": 4096,
"orca-mini": 2048,
"phi3:3.8b": 12800,
"phi3:14b": 128000,
"qwen:0.5b": 32000,
"qwen:1.8b": 32000,
"qwen:4b": 32000,
"qwen:14b": 32000,
"qwen:32b": 32000,
"qwen:72b": 32000,
"qwen:110b": 32000,
"stablelm-zephyr": 8192,
"wizardlm2:8x22b": 65536,
"mistral": 128000,
"gemma2": 128000,
"gemma2:9b": 128000,
"gemma2:27b": 128000,
# embedding models
"shaw/dmeta-embedding-zh-small-q4": 8192,
"shaw/dmeta-embedding-zh-q4": 8192,
"chevalblanc/acge_text_embedding": 8192,
"martcreation/dmeta-embedding-zh": 8192,
"snowflake-arctic-embed": 8192,
"mxbai-embed-large": 512,
},
"oneapi": {
"qwen-turbo": 6000,
},
"nvidia": {
"meta/llama3-70b-instruct": 419,
"meta/llama3-8b-instruct": 419,
"nemotron-4-340b-instruct": 1024,
"databricks/dbrx-instruct": 4096,
"google/codegemma-7b": 8192,
"google/gemma-2b": 2048,
"google/gemma-7b": 8192,
"google/recurrentgemma-2b": 2048,
"meta/codellama-70b": 16384,
"meta/llama2-70b": 4096,
"microsoft/phi-3-mini-128k-instruct": 122880,
"mistralai/mistral-7b-instruct-v0.2": 4096,
"mistralai/mistral-large": 8192,
"mistralai/mixtral-8x22b-instruct-v0.1": 32768,
"mistralai/mixtral-8x7b-instruct-v0.1": 8192,
"snowflake/arctic": 16384,
},
"groq": {
"llama3-8b-8192": 8192,
"llama3-70b-8192": 8192,
"mixtral-8x7b-32768": 32768,
"gemma-7b-it": 8192,
"claude-3-haiku-20240307'": 8192,
},
"toghetherai": {
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000,
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": 128000,
"mistralai/Mixtral-8x22B-Instruct-v0.1": 128000,
"stabilityai/stable-diffusion-xl-base-1.0": 2048,
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": 128000,
"NousResearch/Hermes-3-Llama-3.1-405B-Turbo": 128000,
"Gryphe/MythoMax-L2-13b-Lite": 8192,
"Salesforce/Llama-Rank-V1": 8192,
"meta-llama/Meta-Llama-Guard-3-8B": 128000,
"meta-llama/Meta-Llama-3-70B-Instruct-Turbo": 128000,
"meta-llama/Llama-3-8b-chat-hf": 8192,
"meta-llama/Llama-3-70b-chat-hf": 8192,
"Qwen/Qwen2-72B-Instruct": 128000,
"google/gemma-2-27b-it": 8192,
},
"anthropic": {
"claude_instant": 100000,
"claude2": 9000,
"claude2.1": 200000,
"claude3": 200000,
"claude3.5": 200000,
"claude-3-opus-20240229": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-haiku-20240307": 200000,
"claude-3-5-sonnet-20240620": 200000,
"claude-3-5-sonnet-20241022": 200000,
"claude-3-5-haiku-latest": 200000,
},
"bedrock": {
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"claude-3-5-haiku-latest": 200000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2": 100000,
"anthropic.claude-instant-v1": 100000,
"meta.llama3-8b-instruct-v1:0": 8192,
"meta.llama3-70b-instruct-v1:0": 8192,
"meta.llama2-13b-chat-v1": 4096,
"meta.llama2-70b-chat-v1": 4096,
"mistral.mistral-7b-instruct-v0:2": 32768,
"mistral.mixtral-8x7b-instruct-v0:1": 32768,
"mistral.mistral-large-2402-v1:0": 32768,
"mistral.mistral-small-2402-v1:0": 32768,
"amazon.titan-embed-text-v1": 8000,
"amazon.titan-embed-text-v2:0": 8000,
"cohere.embed-english-v3": 512,
"cohere.embed-multilingual-v3": 512,
},
"mistralai": {
"mistral-large-latest": 128000,
"open-mistral-nemo": 128000,
"codestral-latest": 32000,
"mistral-embed": 8000,
"open-mistral-7b": 32000,
"open-mixtral-8x7b": 32000,
"open-mixtral-8x22b": 64000,
"open-codestral-mamba": 256000,
},
"hugging_face": {
"xai-org/grok-1": 8192,
"meta-llama/Meta-Llama-3-8B": 8192,
"meta-llama/Meta-Llama-3-8B-Instruct": 8192,
"meta-llama/Meta-Llama-3-70B": 8192,
"meta-llama/Meta-Llama-3-70B-Instruct": 8192,
"google/gemma-2b": 8192,
"google/gemma-2b-it": 8192,
"google/gemma-7b": 8192,
"google/gemma-7b-it": 8192,
"microsoft/phi-2": 2048,
"openai-community/gpt2": 1024,
"openai-community/gpt2-medium": 1024,
"openai-community/gpt2-large": 1024,
"facebook/opt-125m": 2048,
"petals-team/StableBeluga2": 8192,
"distilbert/distilgpt2": 1024,
"mistralai/Mistral-7B-Instruct-v0.2": 32768,
"gradientai/Llama-3-8B-Instruct-Gradient-1048k": 1040200,
"NousResearch/Hermes-2-Pro-Llama-3-8B": 8192,
"NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF": 8192,
"nvidia/Llama3-ChatQA-1.5-8B": 8192,
"microsoft/Phi-3-mini-4k-instruct": 4192,
"microsoft/Phi-3-mini-128k-instruct": 131072,
"mlabonne/Meta-Llama-3-120B-Instruct": 8192,
"cognitivecomputations/dolphin-2.9-llama3-8b": 8192,
"cognitivecomputations/dolphin-2.9-llama3-8b-gguf": 8192,
"cognitivecomputations/dolphin-2.8-mistral-7b-v02": 32768,
"cognitivecomputations/dolphin-2.5-mixtral-8x7b": 32768,
"TheBloke/dolphin-2.7-mixtral-8x7b-GGUF": 32768,
"deepseek-ai/DeepSeek-V2": 131072,
"deepseek-ai/DeepSeek-V2-Chat": 131072,
"claude-3-haiku": 200000,
},
"deepseek": {
"deepseek-chat": 28672,
"deepseek-coder": 16384,
},
"ernie": {
"ernie-bot-turbo": 4096,
"ernie-bot": 4096,
"ernie-bot-2": 4096,
"ernie-bot-2-base": 4096,
"ernie-bot-2-base-zh": 4096,
"ernie-bot-2-base-en": 4096,
"ernie-bot-2-base-en-zh": 4096,
"ernie-bot-2-base-zh-en": 4096,
},
"fireworks": {
"llama-v2-7b": 4096,
"mixtral-8x7b-instruct": 4096,
"nomic-ai/nomic-embed-text-v1.5": 8192,
"llama-3.1-405B-instruct": 131072,
"llama-3.1-70B-instruct": 131072,
"llama-3.1-8B-instruct": 131072,
"mixtral-moe-8x22B-instruct": 65536,
"mixtral-moe-8x7B-instruct": 65536,
},
"togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000},
}
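
The dictionary above is what get_model_token_limit consults when create_agent decides how to cap an agent's context. As a rough sketch only — the real implementation in agent_utils.py may differ, and the "provider"/"model" config keys are taken from the tests further down — the lookup can be as simple as:

from typing import Any, Dict, Optional

from ra_aid.models_tokens import models_tokens


def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]:
    """Return the context-window size for the configured provider/model, or None."""
    try:
        provider = config.get("provider")
        model = config.get("model")
        # Unknown providers or models simply yield None so callers can fall
        # back to DEFAULT_TOKEN_LIMIT.
        return models_tokens.get(provider, {}).get(model)
    except Exception:
        # A failed lookup should never break agent creation.
        return None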

View File

@@ -1,5 +1,5 @@
 import os
-from typing import List, Optional, Dict, Union
+from typing import List, Dict, Union
 from ra_aid.tools.memory import _global_memory
 from langchain_core.tools import tool
 from rich.console import Console
@@ -7,7 +7,6 @@ from rich.panel import Panel
 from rich.markdown import Markdown
 from rich.text import Text
 from ra_aid.proc.interactive import run_interactive_command
-from pydantic import BaseModel, Field
 from ra_aid.text.processing import truncate_output

 console = Console()

View File

@@ -0,0 +1,204 @@
"""Unit tests for agent_utils.py."""

import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT
from ra_aid.agent_utils import state_modifier, AgentState
from unittest.mock import Mock, patch
from langchain_core.language_models import BaseChatModel
from ra_aid.agent_utils import create_agent, get_model_token_limit
from ra_aid.models_tokens import models_tokens


@pytest.fixture
def mock_model():
    """Fixture providing a mock LLM model."""
    model = Mock(spec=BaseChatModel)
    return model


@pytest.fixture
def mock_memory():
    """Fixture providing a mock global memory store."""
    with patch("ra_aid.agent_utils._global_memory") as mock_mem:
        mock_mem.get.return_value = {}
        yield mock_mem


def test_get_model_token_limit_anthropic(mock_memory):
    """Test get_model_token_limit with Anthropic model."""
    config = {"provider": "anthropic", "model": "claude2"}
    token_limit = get_model_token_limit(config)
    assert token_limit == models_tokens["anthropic"]["claude2"]


def test_get_model_token_limit_openai(mock_memory):
    """Test get_model_token_limit with OpenAI model."""
    config = {"provider": "openai", "model": "gpt-4"}
    token_limit = get_model_token_limit(config)
    assert token_limit == models_tokens["openai"]["gpt-4"]


def test_get_model_token_limit_unknown(mock_memory):
    """Test get_model_token_limit with unknown provider/model."""
    config = {"provider": "unknown", "model": "unknown-model"}
    token_limit = get_model_token_limit(config)
    assert token_limit is None


def test_get_model_token_limit_missing_config(mock_memory):
    """Test get_model_token_limit with missing configuration."""
    config = {}
    token_limit = get_model_token_limit(config)
    assert token_limit is None


def test_create_agent_anthropic(mock_model, mock_memory):
    """Test create_agent with Anthropic Claude model."""
    mock_memory.get.return_value = {"provider": "anthropic", "model": "claude-2"}

    with patch("ra_aid.agent_utils.create_react_agent") as mock_react:
        mock_react.return_value = "react_agent"
        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        mock_react.assert_called_once_with(
            mock_model, [], state_modifier=mock_react.call_args[1]["state_modifier"]
        )


def test_create_agent_openai(mock_model, mock_memory):
    """Test create_agent with OpenAI model."""
    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}

    with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
        mock_ciayn.return_value = "ciayn_agent"
        agent = create_agent(mock_model, [])

        assert agent == "ciayn_agent"
        mock_ciayn.assert_called_once_with(
            mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"]
        )


def test_create_agent_no_token_limit(mock_model, mock_memory):
    """Test create_agent when no token limit is found."""
    mock_memory.get.return_value = {"provider": "unknown", "model": "unknown-model"}

    with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
        mock_ciayn.return_value = "ciayn_agent"
        agent = create_agent(mock_model, [])

        assert agent == "ciayn_agent"
        mock_ciayn.assert_called_once_with(
            mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
        )


def test_create_agent_missing_config(mock_model, mock_memory):
    """Test create_agent with missing configuration."""
    mock_memory.get.return_value = {"provider": "openai"}

    with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
        mock_ciayn.return_value = "ciayn_agent"
        agent = create_agent(mock_model, [])

        assert agent == "ciayn_agent"
        mock_ciayn.assert_called_once_with(
            mock_model,
            [],
            max_tokens=DEFAULT_TOKEN_LIMIT,
        )


@pytest.fixture
def mock_messages():
    """Fixture providing mock message objects."""
    return [
        SystemMessage(content="System prompt"),
        HumanMessage(content="Human message 1"),
        AIMessage(content="AI response 1"),
        HumanMessage(content="Human message 2"),
        AIMessage(content="AI response 2"),
    ]


def test_state_modifier(mock_messages):
    """Test that state_modifier correctly trims recent messages while preserving the first message when total tokens > max_tokens."""
    state = AgentState(messages=mock_messages)

    with patch(
        "ra_aid.agents.ciayn_agent.CiaynAgent._estimate_tokens"
    ) as mock_estimate:
        mock_estimate.side_effect = lambda msg: 100 if msg else 0

        result = state_modifier(state, max_tokens=250)

        assert len(result) < len(mock_messages)
        assert isinstance(result[0], SystemMessage)
        assert result[-1] == mock_messages[-1]


def test_create_agent_with_checkpointer(mock_model, mock_memory):
    """Test create_agent with checkpointer argument."""
    mock_memory.get.return_value = {"provider": "openai", "model": "gpt-4"}
    mock_checkpointer = Mock()

    with patch("ra_aid.agent_utils.CiaynAgent") as mock_ciayn:
        mock_ciayn.return_value = "ciayn_agent"
        agent = create_agent(mock_model, [], checkpointer=mock_checkpointer)

        assert agent == "ciayn_agent"
        mock_ciayn.assert_called_once_with(
            mock_model, [], max_tokens=models_tokens["openai"]["gpt-4"]
        )


def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_memory):
    """Test create_agent sets up token limiting for Claude models when enabled."""
    mock_memory.get.return_value = {
        "provider": "anthropic",
        "model": "claude-2",
        "limit_tokens": True,
    }

    with (
        patch("ra_aid.agent_utils.create_react_agent") as mock_react,
        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
    ):
        mock_react.return_value = "react_agent"
        mock_limit.return_value = 100000

        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        args = mock_react.call_args
        assert "state_modifier" in args[1]
        assert callable(args[1]["state_modifier"])


def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory):
    """Test create_agent doesn't set up token limiting for Claude models when disabled."""
    mock_memory.get.return_value = {
        "provider": "anthropic",
        "model": "claude-2",
        "limit_tokens": False,
    }

    with (
        patch("ra_aid.agent_utils.create_react_agent") as mock_react,
        patch("ra_aid.agent_utils.get_model_token_limit") as mock_limit,
    ):
        mock_react.return_value = "react_agent"
        mock_limit.return_value = 100000

        agent = create_agent(mock_model, [])

        assert agent == "react_agent"
        mock_react.assert_called_once_with(mock_model, [])
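
The two token-limiting tests above only assert that create_agent hands create_react_agent a callable state_modifier when limiting is enabled, and bare (model, tools) arguments when it is disabled. A minimal sketch of wiring that would behave this way, assuming the state_modifier(state, max_tokens=...) signature exercised in test_state_modifier — the helper name _token_limit_kwargs and the limit-on-by-default behavior are illustrative assumptions, not the actual agent_utils.py code:

import functools
from typing import Any, Dict

from ra_aid.agent_utils import get_model_token_limit, state_modifier
from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT


def _token_limit_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
    """Sketch: build extra create_react_agent kwargs for Anthropic models only."""
    kwargs: Dict[str, Any] = {}
    if config.get("provider") == "anthropic" and config.get("limit_tokens", True):
        max_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT
        # Bind the limit so create_react_agent receives a one-argument callable.
        kwargs["state_modifier"] = functools.partial(state_modifier, max_tokens=max_tokens)
    return kwargs

Binding the limit with functools.partial keeps max_tokens out of the callable's signature, so the enabled test's callable(...) assertion passes, while returning an empty dict when limit_tokens is False reproduces the plain create_react_agent(model, tools) call checked by the disabled test.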