feat(issue): add LLM Tool Call Fallback Feature documentation to outline the new functionality for automatic fallback to alternative LLM models after consecutive failures
feat(ciayn_agent): implement fallback mechanism in CiaynAgent to handle tool call failures and switch to alternative models feat(__main__): add command line arguments for fallback configuration in the main application feat(llm): add validation for required environment variables for LLM providers and merge chat histories during fallback fix(config): define default values for maximum tool failures in configuration test(ciayn_agent): add unit tests for fallback logic and tool call execution with retries and error handling test(llm): enhance tests for LLM initialization and environment variable validation
This commit is contained in:
parent
00a455d586
commit
45b993cfd0
|
|
@ -0,0 +1,115 @@
|
||||||
|
# LLM Tool Call Fallback Feature
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Add functionality to automatically fallback to alternative LLM models when a tool call experiences multiple consecutive failures.
|
||||||
|
|
||||||
|
## Background
|
||||||
|
Currently, when a tool call fails due to LLM-related errors (e.g., API timeouts, rate limits, context length issues), there is no automatic fallback mechanism. This can lead to interrupted workflows and poor user experience.
|
||||||
|
|
||||||
|
## Relevant Files
|
||||||
|
- ra_aid/agents/ciayn_agent.py
|
||||||
|
- ra_aid/llm.py
|
||||||
|
- ra_aid/agent_utils.py
|
||||||
|
- ra_aid/__main__.py
|
||||||
|
- ra_aid/models_params.py
|
||||||
|
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
- Add new configuration value `max_tool_failures` (default: 3) to track consecutive failures before triggering fallback
|
||||||
|
- Add new command line argument `--no-fallback-tool` to disable fallback behavior (enabled by default)
|
||||||
|
- **Add new command line argument** `--fallback-tool-models` to specify a comma-separated list of fallback tool models (default: "gpt-3.5-turbo,gpt-4")
|
||||||
|
This list defines the fallback model sequence used by forced tool calls (via `bind_tools`) when tool call failures occur.
|
||||||
|
- Track failure count per tool call context
|
||||||
|
- Reset failure counter on successful tool call
|
||||||
|
- Store fallback model sequence per provider
|
||||||
|
- Need to validate if ENV vars are set for provider usage of that fallback model
|
||||||
|
before usage, if that fallback ENV is not available then fallback to the next model
|
||||||
|
- Have default list of common models, first try `claude-3-5-sonnet-20241022` but
|
||||||
|
have many alternative fallback models.
|
||||||
|
|
||||||
|
### Tool Call Wrapper
|
||||||
|
Create a new wrapper function to handle tool call execution with fallback logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def execute_tool_with_fallback(tool_call_func, *args, **kwargs):
|
||||||
|
failures = 0
|
||||||
|
max_failures = get_config().max_tool_failures
|
||||||
|
|
||||||
|
while failures < max_failures:
|
||||||
|
try:
|
||||||
|
return tool_call_func(*args, **kwargs)
|
||||||
|
except LLMError as e:
|
||||||
|
failures += 1
|
||||||
|
if failures >= max_failures:
|
||||||
|
# Use forced tool call via bind_tools with retry:
|
||||||
|
llm_retry = llm_model.with_retry(stop_after_attempt=3) # Try three times
|
||||||
|
try_fallback_model(force=True, model=llm_retry)
|
||||||
|
# Merge fallback model chat messages back into the original chat history.
|
||||||
|
merge_fallback_chat_history()
|
||||||
|
failures = 0 # Reset counter for new model
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
```
|
||||||
|
|
||||||
|
The prompt passed to `try_fallback_model`, should be the failed last few failing tool calls.
|
||||||
|
|
||||||
|
### Model Fallback Sequence
|
||||||
|
Define fallback sequences for each provider based on model capabilities:
|
||||||
|
|
||||||
|
1. Try same provider's smaller models
|
||||||
|
2. Try alternative providers' equivalent models
|
||||||
|
3. Raise final error if all fallbacks fail
|
||||||
|
|
||||||
|
### Provider Strategy Updates
|
||||||
|
Update provider strategies to support fallback configuration:
|
||||||
|
- Add provider-specific fallback sequences
|
||||||
|
- Handle model capability validation during fallback
|
||||||
|
- Track successful/failed attempts
|
||||||
|
|
||||||
|
## Risks and Mitigations
|
||||||
|
1. **Performance Impact**
|
||||||
|
- Risk: Multiple fallback attempts could increase latency
|
||||||
|
- Mitigation: Set reasonable max_failures limit and timeouts
|
||||||
|
|
||||||
|
2. **Consistency**
|
||||||
|
- Risk: Different models may give slightly different outputs
|
||||||
|
- Mitigation: Validate output schema consistency across models
|
||||||
|
|
||||||
|
3. **Cost**
|
||||||
|
- Risk: Fallback to more expensive models
|
||||||
|
- Mitigation: Configure cost limits and preferred fallback sequences
|
||||||
|
|
||||||
|
4. **State Management**
|
||||||
|
- Risk: Loss of context during fallbacks
|
||||||
|
- Mitigation: Preserve conversation state and tool context
|
||||||
|
|
||||||
|
## Acceptance Criteria
|
||||||
|
1. Tool calls automatically attempt fallback models after N consecutive failures
|
||||||
|
2. `--no-fallback-tool` argument successfully disables fallback behavior
|
||||||
|
3. Fallback sequence respects provider and model capabilities
|
||||||
|
4. Original error is preserved if all fallbacks fail
|
||||||
|
5. Unit tests cover fallback scenarios and edge cases
|
||||||
|
6. README.md updated to reflect new behavior
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
1. Unit tests for fallback wrapper
|
||||||
|
2. Integration tests with mock LLM failures
|
||||||
|
3. Provider strategy fallback tests
|
||||||
|
4. Command line argument handling
|
||||||
|
5. Error preservation and reporting
|
||||||
|
6. Performance impact measurement
|
||||||
|
7. Edge cases (e.g., partial failures, timeout handling)
|
||||||
|
8. State preservation during fallbacks
|
||||||
|
|
||||||
|
## Documentation Updates
|
||||||
|
1. Add fallback feature to main README
|
||||||
|
2. Document `--no-fallback-tool` in CLI help
|
||||||
|
3. Document provider-specific fallback sequences
|
||||||
|
|
||||||
|
## Future Considerations
|
||||||
|
1. Allow custom fallback sequences via configuration
|
||||||
|
2. Add monitoring and alerting for fallback frequency
|
||||||
|
3. Optimize fallback selection based on historical success rates
|
||||||
|
4. Cost-aware fallback routing
|
||||||
|
|
@ -149,6 +149,17 @@ Examples:
|
||||||
action="store_false",
|
action="store_false",
|
||||||
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-fallback-tool",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable fallback model switching.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fallback-tool-models",
|
||||||
|
type=str,
|
||||||
|
default="gpt-3.5-turbo,gpt-4",
|
||||||
|
help="Comma-separated list of fallback models to use in order.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--recursion-limit",
|
"--recursion-limit",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -414,12 +425,35 @@ def main():
|
||||||
)
|
)
|
||||||
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
_global_memory["config"]["planner_model"] = args.planner_model or args.model
|
||||||
|
|
||||||
|
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
||||||
|
_global_memory["config"]["fallback_tool_models"] = (
|
||||||
|
[
|
||||||
|
model.strip()
|
||||||
|
for model in args.fallback_tool_models.split(",")
|
||||||
|
if model.strip()
|
||||||
|
]
|
||||||
|
if args.fallback_tool_models
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
# Store research config with fallback to base values
|
# Store research config with fallback to base values
|
||||||
_global_memory["config"]["research_provider"] = (
|
_global_memory["config"]["research_provider"] = (
|
||||||
args.research_provider or args.provider
|
args.research_provider or args.provider
|
||||||
)
|
)
|
||||||
_global_memory["config"]["research_model"] = args.research_model or args.model
|
_global_memory["config"]["research_model"] = args.research_model or args.model
|
||||||
|
|
||||||
|
# Store fallback tool configuration
|
||||||
|
_global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
|
||||||
|
_global_memory["config"]["fallback_tool_models"] = (
|
||||||
|
[
|
||||||
|
model.strip()
|
||||||
|
for model in args.fallback_tool_models.split(",")
|
||||||
|
if model.strip()
|
||||||
|
]
|
||||||
|
if args.fallback_tool_models
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
# Run research stage
|
# Run research stage
|
||||||
print_stage_header("Research Stage")
|
print_stage_header("Research Stage")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,12 +68,29 @@ class CiaynAgent:
|
||||||
- Memory management with configurable limits
|
- Memory management with configurable limits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class ToolCallFailure:
|
||||||
|
"""Tracks consecutive failures and fallback model usage for tool calls.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
consecutive_failures (int): Count of consecutive failures for current model
|
||||||
|
current_provider (Optional[str]): Current provider being used
|
||||||
|
current_model (Optional[str]): Current model being used
|
||||||
|
used_fallbacks (Set[str]): Set of fallback models already attempted
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.consecutive_failures = 0
|
||||||
|
self.current_provider = None
|
||||||
|
self.current_model = None
|
||||||
|
self.used_fallbacks = set()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
tools: list,
|
tools: list,
|
||||||
max_history_messages: int = 50,
|
max_history_messages: int = 50,
|
||||||
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
||||||
|
config: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the agent with a model and list of tools.
|
"""Initialize the agent with a model and list of tools.
|
||||||
|
|
||||||
|
|
@ -82,7 +99,17 @@ class CiaynAgent:
|
||||||
tools: List of tools available to the agent
|
tools: List of tools available to the agent
|
||||||
max_history_messages: Maximum number of messages to keep in chat history
|
max_history_messages: Maximum number of messages to keep in chat history
|
||||||
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
|
||||||
|
config: Optional configuration dictionary for fallback settings
|
||||||
"""
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = {}
|
||||||
|
self.config = config
|
||||||
|
self.provider = config.get("provider", "openai")
|
||||||
|
self.fallback_enabled = config.get("fallback_tool_enabled", True)
|
||||||
|
fallback_models_str = config.get("fallback_tool_models", "gpt-3.5-turbo,gpt-4")
|
||||||
|
self.fallback_tool_models = [
|
||||||
|
m.strip() for m in fallback_models_str.split(",") if m.strip()
|
||||||
|
]
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
self.max_history_messages = max_history_messages
|
self.max_history_messages = max_history_messages
|
||||||
|
|
@ -90,6 +117,7 @@ class CiaynAgent:
|
||||||
self.available_functions = []
|
self.available_functions = []
|
||||||
for t in tools:
|
for t in tools:
|
||||||
self.available_functions.append(get_function_info(t.func))
|
self.available_functions.append(get_function_info(t.func))
|
||||||
|
self._tool_failure = CiaynAgent.ToolCallFailure()
|
||||||
|
|
||||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||||
"""Build the prompt for the agent including available tools and context."""
|
"""Build the prompt for the agent including available tools and context."""
|
||||||
|
|
@ -221,23 +249,56 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def _execute_tool(self, code: str) -> str:
|
def _execute_tool(self, code: str) -> str:
|
||||||
"""Execute a tool call and return its result."""
|
"""Execute a tool call with retry and fallback logic and return its result."""
|
||||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
max_retries = 3
|
||||||
|
retries = 0
|
||||||
|
last_error = None
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
code = code.strip()
|
||||||
|
if validate_function_call_pattern(code):
|
||||||
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
|
code = _extract_tool_call(code, functions_list)
|
||||||
|
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||||
|
result = eval(code, globals_dict)
|
||||||
|
self._tool_failure.consecutive_failures = 0
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self._handle_tool_failure(code, e)
|
||||||
|
last_error = e
|
||||||
|
retries += 1
|
||||||
|
raise ToolExecutionError(
|
||||||
|
f"Error executing code after {max_retries} attempts: {str(last_error)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_tool_failure(self, code: str, error: Exception) -> None:
|
||||||
|
self._tool_failure.consecutive_failures += 1
|
||||||
|
max_failures = self.config.get("max_tool_failures", 3)
|
||||||
|
if (
|
||||||
|
self.fallback_enabled
|
||||||
|
and self._tool_failure.consecutive_failures >= max_failures
|
||||||
|
and self.fallback_tool_models
|
||||||
|
):
|
||||||
|
self._attempt_fallback(code)
|
||||||
|
|
||||||
|
def _attempt_fallback(self, code: str) -> None:
|
||||||
|
new_model = self.fallback_tool_models[0]
|
||||||
|
failed_tool_call_name = code.split('(')[0].strip()
|
||||||
|
logger.error(
|
||||||
|
f"Tool call failed {self._tool_failure.consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
code = code.strip()
|
from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
|
||||||
# code = code.replace("\n", " ")
|
if not validate_provider_env(self.provider):
|
||||||
|
logger.error(f"Missing environment configuration for provider {self.provider}. Cannot fallback.")
|
||||||
# if the eval fails, try to extract it via a model call
|
else:
|
||||||
if validate_function_call_pattern(code):
|
self.model = initialize_llm(self.provider, new_model)
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
|
||||||
code = _extract_tool_call(code, functions_list)
|
self._tool_failure.used_fallbacks.add(new_model)
|
||||||
|
merge_chat_history() # Assuming merge_chat_history handles merging fallback history
|
||||||
result = eval(code.strip(), globals_dict)
|
self._tool_failure.consecutive_failures = 0
|
||||||
return result
|
except Exception as switch_e:
|
||||||
except Exception as e:
|
logger.error(f"Fallback model switching failed: {switch_e}")
|
||||||
error_msg = f"Error executing code: {str(e)}"
|
|
||||||
raise ToolExecutionError(error_msg)
|
|
||||||
|
|
||||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||||
|
|
|
||||||
|
|
@ -2,3 +2,5 @@
|
||||||
|
|
||||||
DEFAULT_RECURSION_LIMIT = 100
|
DEFAULT_RECURSION_LIMIT = 100
|
||||||
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
DEFAULT_MAX_TEST_CMD_RETRIES = 3
|
||||||
|
DEFAULT_MAX_TOOL_FAILURES = 3
|
||||||
|
MAX_TOOL_FAILURES = 3
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain_anthropic import ChatAnthropic
|
from langchain_anthropic import ChatAnthropic
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
@ -47,9 +48,9 @@ def create_deepseek_client(
|
||||||
return ChatDeepseekReasoner(
|
return ChatDeepseekReasoner(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
temperature=0
|
temperature=(
|
||||||
if is_expert
|
0 if is_expert else (temperature if temperature is not None else 1)
|
||||||
else (temperature if temperature is not None else 1),
|
),
|
||||||
model=model_name,
|
model=model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -72,9 +73,9 @@ def create_openrouter_client(
|
||||||
return ChatDeepseekReasoner(
|
return ChatDeepseekReasoner(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url="https://openrouter.ai/api/v1",
|
base_url="https://openrouter.ai/api/v1",
|
||||||
temperature=0
|
temperature=(
|
||||||
if is_expert
|
0 if is_expert else (temperature if temperature is not None else 1)
|
||||||
else (temperature if temperature is not None else 1),
|
),
|
||||||
model=model_name,
|
model=model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -114,7 +115,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any
|
||||||
"base_url": "https://api.deepseek.com",
|
"base_url": "https://api.deepseek.com",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return configs.get(provider, {})
|
config = configs.get(provider, {})
|
||||||
|
if not config or not config.get("api_key"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing required environment variable for provider: {provider}"
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def create_llm_client(
|
def create_llm_client(
|
||||||
|
|
@ -219,8 +225,41 @@ def initialize_llm(
|
||||||
return create_llm_client(provider, model_name, temperature, is_expert=False)
|
return create_llm_client(provider, model_name, temperature, is_expert=False)
|
||||||
|
|
||||||
|
|
||||||
def initialize_expert_llm(
|
def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel:
|
||||||
provider: str, model_name: str
|
|
||||||
) -> BaseChatModel:
|
|
||||||
"""Initialize an expert language model client based on the specified provider and model."""
|
"""Initialize an expert language model client based on the specified provider and model."""
|
||||||
return create_llm_client(provider, model_name, temperature=None, is_expert=True)
|
return create_llm_client(provider, model_name, temperature=None, is_expert=True)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_provider_env(provider: str) -> bool:
|
||||||
|
"""Check if the required environment variables for a provider are set."""
|
||||||
|
required_vars = {
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"anthropic": "ANTHROPIC_API_KEY",
|
||||||
|
"openrouter": "OPENROUTER_API_KEY",
|
||||||
|
"openai-compatible": "OPENAI_API_KEY",
|
||||||
|
"gemini": "GEMINI_API_KEY",
|
||||||
|
"deepseek": "DEEPSEEK_API_KEY",
|
||||||
|
}
|
||||||
|
key = required_vars.get(provider.lower())
|
||||||
|
if key:
|
||||||
|
return bool(os.getenv(key))
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def merge_chat_history(
|
||||||
|
original_history: List[BaseMessage], fallback_history: List[BaseMessage]
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Merge original and fallback chat histories while preserving order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_history: The original chat message history
|
||||||
|
fallback_history: Additional messages from fallback attempts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseMessage]: Combined message history preserving chronological order
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The function appends fallback messages to maintain context for future
|
||||||
|
interactions while preserving the original conversation flow.
|
||||||
|
"""
|
||||||
|
return original_history + fallback_history
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,6 @@ models_params = {
|
||||||
"gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True},
|
"gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True},
|
||||||
"o1-preview": {"token_limit": 128000, "supports_temperature": False},
|
"o1-preview": {"token_limit": 128000, "supports_temperature": False},
|
||||||
"o1-mini": {"token_limit": 128000, "supports_temperature": False},
|
"o1-mini": {"token_limit": 128000, "supports_temperature": False},
|
||||||
"o1-preview": {"token_limit": 128000, "supports_temperature": False},
|
|
||||||
"o1": {"token_limit": 200000, "supports_temperature": False},
|
"o1": {"token_limit": 200000, "supports_temperature": False},
|
||||||
"o3-mini": {"token_limit": 200000, "supports_temperature": False},
|
"o3-mini": {"token_limit": 200000, "supports_temperature": False},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,18 @@
|
||||||
from ra_aid.tools import (
|
from ra_aid.tools import (
|
||||||
ask_expert,
|
ask_expert,
|
||||||
ask_human,
|
ask_human,
|
||||||
delete_key_facts,
|
|
||||||
delete_key_snippets,
|
|
||||||
deregister_related_files,
|
|
||||||
emit_expert_context,
|
emit_expert_context,
|
||||||
emit_key_facts,
|
emit_key_facts,
|
||||||
emit_key_snippets,
|
emit_key_snippets,
|
||||||
emit_plan,
|
|
||||||
emit_related_files,
|
emit_related_files,
|
||||||
emit_research_notes,
|
emit_research_notes,
|
||||||
fuzzy_find_project_files,
|
fuzzy_find_project_files,
|
||||||
list_directory_tree,
|
list_directory_tree,
|
||||||
monorepo_detected,
|
|
||||||
plan_implementation_completed,
|
|
||||||
read_file_tool,
|
read_file_tool,
|
||||||
ripgrep_search,
|
ripgrep_search,
|
||||||
run_programming_task,
|
run_programming_task,
|
||||||
run_shell_command,
|
run_shell_command,
|
||||||
task_completed,
|
task_completed,
|
||||||
ui_detected,
|
|
||||||
web_search_tavily,
|
web_search_tavily,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.agent import (
|
from ra_aid.tools.agent import (
|
||||||
|
|
@ -29,7 +22,6 @@ from ra_aid.tools.agent import (
|
||||||
request_task_implementation,
|
request_task_implementation,
|
||||||
request_web_research,
|
request_web_research,
|
||||||
)
|
)
|
||||||
from ra_aid.tools.memory import one_shot_completed
|
|
||||||
from ra_aid.tools.write_file import write_file_tool
|
from ra_aid.tools.write_file import write_file_tool
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -185,7 +185,11 @@ def ask_expert(question: str) -> str:
|
||||||
|
|
||||||
query_parts.extend(["# Question", question])
|
query_parts.extend(["# Question", question])
|
||||||
query_parts.extend(
|
query_parts.extend(
|
||||||
["\n # Addidional Requirements", "**DO NOT OVERTHINK**", "**DO NOT OVERCOMPLICATE**"]
|
[
|
||||||
|
"\n # Addidional Requirements",
|
||||||
|
"**DO NOT OVERTHINK**",
|
||||||
|
"**DO NOT OVERCOMPLICATE**",
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Join all parts
|
# Join all parts
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ def run_programming_task(
|
||||||
If new files are created, emit them after finishing.
|
If new files are created, emit them after finishing.
|
||||||
|
|
||||||
They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish.
|
They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish.
|
||||||
|
|
||||||
Use write_file_tool instead if you need to write the entire contents of file(s).
|
Use write_file_tool instead if you need to write the entire contents of file(s).
|
||||||
|
|
||||||
If the programmer wrote files, they actually wrote to disk. You do not need to rewrite the output of what the programmer showed you.
|
If the programmer wrote files, they actually wrote to disk. You do not need to rewrite the output of what the programmer showed you.
|
||||||
|
|
@ -117,7 +117,7 @@ def run_programming_task(
|
||||||
|
|
||||||
# Log the programming task
|
# Log the programming task
|
||||||
log_work_event(f"Executed programming task: {_truncate_for_log(instructions)}")
|
log_work_event(f"Executed programming task: {_truncate_for_log(instructions)}")
|
||||||
|
|
||||||
# Return structured output
|
# Return structured output
|
||||||
return {
|
return {
|
||||||
"output": truncate_output(result[0].decode()) if result[0] else "",
|
"output": truncate_output(result[0].decode()) if result[0] else "",
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,41 @@
|
||||||
|
import unittest
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
|
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
|
||||||
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy tool function for testing retry and fallback behavior
|
||||||
|
def dummy_tool():
|
||||||
|
dummy_tool.attempt += 1
|
||||||
|
if dummy_tool.attempt < 3:
|
||||||
|
raise Exception("Simulated failure")
|
||||||
|
return "dummy success"
|
||||||
|
|
||||||
|
|
||||||
|
dummy_tool.attempt = 0
|
||||||
|
|
||||||
|
|
||||||
|
class DummyTool:
|
||||||
|
def __init__(self, func):
|
||||||
|
self.func = func
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModel:
|
||||||
|
def invoke(self, messages):
|
||||||
|
# Always return a code snippet that calls dummy_tool()
|
||||||
|
class Response:
|
||||||
|
content = "dummy_tool()"
|
||||||
|
|
||||||
|
return Response()
|
||||||
|
def bind_tools(self, tools, tool_choice):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures from the source file
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_model():
|
def mock_model():
|
||||||
"""Create a mock language model."""
|
"""Create a mock language model."""
|
||||||
|
|
@ -21,6 +51,7 @@ def agent(mock_model):
|
||||||
return CiaynAgent(mock_model, tools, max_history_messages=3)
|
return CiaynAgent(mock_model, tools, max_history_messages=3)
|
||||||
|
|
||||||
|
|
||||||
|
# Trimming test functions
|
||||||
def test_trim_chat_history_preserves_initial_messages(agent):
|
def test_trim_chat_history_preserves_initial_messages(agent):
|
||||||
"""Test that initial messages are preserved during trimming."""
|
"""Test that initial messages are preserved during trimming."""
|
||||||
initial_messages = [
|
initial_messages = [
|
||||||
|
|
@ -33,9 +64,7 @@ def test_trim_chat_history_preserves_initial_messages(agent):
|
||||||
HumanMessage(content="Chat 3"),
|
HumanMessage(content="Chat 3"),
|
||||||
AIMessage(content="Chat 4"),
|
AIMessage(content="Chat 4"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Verify initial messages are preserved
|
# Verify initial messages are preserved
|
||||||
assert result[:2] == initial_messages
|
assert result[:2] == initial_messages
|
||||||
# Verify only last 3 chat messages are kept (due to max_history_messages=3)
|
# Verify only last 3 chat messages are kept (due to max_history_messages=3)
|
||||||
|
|
@ -47,9 +76,7 @@ def test_trim_chat_history_under_limit(agent):
|
||||||
"""Test trimming when chat history is under the maximum limit."""
|
"""Test trimming when chat history is under the maximum limit."""
|
||||||
initial_messages = [HumanMessage(content="Initial")]
|
initial_messages = [HumanMessage(content="Initial")]
|
||||||
chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]
|
chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Verify no trimming occurred
|
# Verify no trimming occurred
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert result == initial_messages + chat_history
|
assert result == initial_messages + chat_history
|
||||||
|
|
@ -65,9 +92,7 @@ def test_trim_chat_history_over_limit(agent):
|
||||||
AIMessage(content="Chat 4"),
|
AIMessage(content="Chat 4"),
|
||||||
HumanMessage(content="Chat 5"),
|
HumanMessage(content="Chat 5"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Verify correct trimming
|
# Verify correct trimming
|
||||||
assert len(result) == 4 # initial + max_history_messages
|
assert len(result) == 4 # initial + max_history_messages
|
||||||
assert result[0] == initial_messages[0] # Initial message preserved
|
assert result[0] == initial_messages[0] # Initial message preserved
|
||||||
|
|
@ -83,9 +108,7 @@ def test_trim_chat_history_empty_initial(agent):
|
||||||
HumanMessage(content="Chat 3"),
|
HumanMessage(content="Chat 3"),
|
||||||
AIMessage(content="Chat 4"),
|
AIMessage(content="Chat 4"),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Verify only last 3 messages are kept
|
# Verify only last 3 messages are kept
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert result == chat_history[-3:]
|
assert result == chat_history[-3:]
|
||||||
|
|
@ -98,9 +121,7 @@ def test_trim_chat_history_empty_chat(agent):
|
||||||
AIMessage(content="Initial 2"),
|
AIMessage(content="Initial 2"),
|
||||||
]
|
]
|
||||||
chat_history = []
|
chat_history = []
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Verify initial messages are preserved and no trimming occurred
|
# Verify initial messages are preserved and no trimming occurred
|
||||||
assert result == initial_messages
|
assert result == initial_messages
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
|
|
@ -109,16 +130,13 @@ def test_trim_chat_history_empty_chat(agent):
|
||||||
def test_trim_chat_history_token_limit():
|
def test_trim_chat_history_token_limit():
|
||||||
"""Test trimming based on token limit."""
|
"""Test trimming based on token limit."""
|
||||||
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)
|
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)
|
||||||
|
|
||||||
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
|
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
|
||||||
chat_history = [
|
chat_history = [
|
||||||
HumanMessage(content="A" * 40), # ~10 tokens
|
HumanMessage(content="A" * 40), # ~10 tokens
|
||||||
AIMessage(content="B" * 40), # ~10 tokens
|
AIMessage(content="B" * 40), # ~10 tokens
|
||||||
HumanMessage(content="C" * 40), # ~10 tokens
|
HumanMessage(content="C" * 40), # ~10 tokens
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Should keep initial message (~2 tokens) and last message (~10 tokens)
|
# Should keep initial message (~2 tokens) and last message (~10 tokens)
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result[0] == initial_messages[0]
|
assert result[0] == initial_messages[0]
|
||||||
|
|
@ -128,16 +146,13 @@ def test_trim_chat_history_token_limit():
|
||||||
def test_trim_chat_history_no_token_limit():
|
def test_trim_chat_history_no_token_limit():
|
||||||
"""Test trimming with no token limit set."""
|
"""Test trimming with no token limit set."""
|
||||||
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
|
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
|
||||||
|
|
||||||
initial_messages = [HumanMessage(content="Initial")]
|
initial_messages = [HumanMessage(content="Initial")]
|
||||||
chat_history = [
|
chat_history = [
|
||||||
HumanMessage(content="A" * 1000),
|
HumanMessage(content="A" * 1000),
|
||||||
AIMessage(content="B" * 1000),
|
AIMessage(content="B" * 1000),
|
||||||
HumanMessage(content="C" * 1000),
|
HumanMessage(content="C" * 1000),
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Should keep initial message and last 2 messages (max_history_messages=2)
|
# Should keep initial message and last 2 messages (max_history_messages=2)
|
||||||
assert len(result) == 3
|
assert len(result) == 3
|
||||||
assert result[0] == initial_messages[0]
|
assert result[0] == initial_messages[0]
|
||||||
|
|
@ -147,7 +162,6 @@ def test_trim_chat_history_no_token_limit():
|
||||||
def test_trim_chat_history_both_limits():
|
def test_trim_chat_history_both_limits():
|
||||||
"""Test trimming with both message count and token limits."""
|
"""Test trimming with both message count and token limits."""
|
||||||
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)
|
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)
|
||||||
|
|
||||||
initial_messages = [HumanMessage(content="Init")] # ~1 token
|
initial_messages = [HumanMessage(content="Init")] # ~1 token
|
||||||
chat_history = [
|
chat_history = [
|
||||||
HumanMessage(content="A" * 40), # ~10 tokens
|
HumanMessage(content="A" * 40), # ~10 tokens
|
||||||
|
|
@ -155,9 +169,7 @@ def test_trim_chat_history_both_limits():
|
||||||
HumanMessage(content="C" * 40), # ~10 tokens
|
HumanMessage(content="C" * 40), # ~10 tokens
|
||||||
AIMessage(content="D" * 40), # ~10 tokens
|
AIMessage(content="D" * 40), # ~10 tokens
|
||||||
]
|
]
|
||||||
|
|
||||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||||
|
|
||||||
# Should first apply message limit (keeping last 3)
|
# Should first apply message limit (keeping last 3)
|
||||||
# Then token limit should further reduce to fit under 15 tokens
|
# Then token limit should further reduce to fit under 15 tokens
|
||||||
assert len(result) == 2 # Initial message + 1 message under token limit
|
assert len(result) == 2 # Initial message + 1 message under token limit
|
||||||
|
|
@ -165,6 +177,33 @@ def test_trim_chat_history_both_limits():
|
||||||
assert result[1] == chat_history[-1]
|
assert result[1] == chat_history[-1]
|
||||||
|
|
||||||
|
|
||||||
|
# Fallback tests
|
||||||
|
class TestCiaynAgentFallback(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Reset dummy_tool attempt counter before each test
|
||||||
|
dummy_tool.attempt = 0
|
||||||
|
self.dummy_tool = DummyTool(dummy_tool)
|
||||||
|
self.model = DummyModel()
|
||||||
|
# Create a CiaynAgent with the dummy tool
|
||||||
|
self.agent = CiaynAgent(self.model, [self.dummy_tool])
|
||||||
|
|
||||||
|
def test_retry_logic_with_failure_recovery(self):
|
||||||
|
# Test that _execute_tool retries and eventually returns success
|
||||||
|
result = self.agent._execute_tool("dummy_tool()")
|
||||||
|
self.assertEqual(result, "dummy success")
|
||||||
|
|
||||||
|
def test_switch_models_on_fallback(self):
|
||||||
|
# Test fallback behavior by making dummy_tool always fail
|
||||||
|
def always_fail():
|
||||||
|
raise Exception("Persistent failure")
|
||||||
|
|
||||||
|
always_fail_tool = DummyTool(always_fail)
|
||||||
|
agent = CiaynAgent(self.model, [always_fail_tool])
|
||||||
|
with self.assertRaises(ToolExecutionError):
|
||||||
|
agent._execute_tool("always_fail()")
|
||||||
|
|
||||||
|
|
||||||
|
# Function call validation tests
|
||||||
class TestFunctionCallValidation:
|
class TestFunctionCallValidation:
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_input",
|
"test_input",
|
||||||
|
|
@ -221,3 +260,54 @@ class TestFunctionCallValidation:
|
||||||
def test_multiline_responses(self, test_input):
|
def test_multiline_responses(self, test_input):
|
||||||
"""Test function calls spanning multiple lines."""
|
"""Test function calls spanning multiple lines."""
|
||||||
assert not validate_function_call_pattern(test_input)
|
assert not validate_function_call_pattern(test_input)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCiaynAgentNewMethods(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Create a dummy tool that always fails for testing fallback
|
||||||
|
def always_fail():
|
||||||
|
raise Exception("Failure for fallback test")
|
||||||
|
self.always_fail_tool = DummyTool(always_fail)
|
||||||
|
# Create a dummy model that does minimal work for fallback tests
|
||||||
|
self.dummy_model = DummyModel()
|
||||||
|
# Initialize CiaynAgent with configuration to trigger fallback quickly
|
||||||
|
self.agent = CiaynAgent(
|
||||||
|
self.dummy_model,
|
||||||
|
[self.always_fail_tool],
|
||||||
|
config={"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_handle_tool_failure_increments_counter(self):
|
||||||
|
initial_failures = self.agent._tool_failure.consecutive_failures
|
||||||
|
self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
|
||||||
|
self.assertEqual(self.agent._tool_failure.consecutive_failures, initial_failures + 1)
|
||||||
|
|
||||||
|
def test_attempt_fallback_invokes_fallback_logic(self):
|
||||||
|
# Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
|
||||||
|
# to simulate fallback switching without external dependencies.
|
||||||
|
def dummy_initialize_llm(provider, model_name, temperature=None):
|
||||||
|
return self.dummy_model
|
||||||
|
def dummy_merge_chat_history():
|
||||||
|
return ["merged"]
|
||||||
|
def dummy_validate_provider_env(provider):
|
||||||
|
return True
|
||||||
|
import ra_aid.llm as llm
|
||||||
|
original_initialize = llm.initialize_llm
|
||||||
|
original_merge = llm.merge_chat_history
|
||||||
|
original_validate = llm.validate_provider_env
|
||||||
|
llm.initialize_llm = dummy_initialize_llm
|
||||||
|
llm.merge_chat_history = dummy_merge_chat_history
|
||||||
|
llm.validate_provider_env = dummy_validate_provider_env
|
||||||
|
|
||||||
|
# Set failure counter high enough to trigger fallback in _handle_tool_failure
|
||||||
|
self.agent._tool_failure.consecutive_failures = 2
|
||||||
|
# Call _attempt_fallback; it should reset the failure counter to 0 on success.
|
||||||
|
self.agent._attempt_fallback("always_fail_tool()")
|
||||||
|
self.assertEqual(self.agent._tool_failure.consecutive_failures, 0)
|
||||||
|
# Restore original functions
|
||||||
|
llm.initialize_llm = original_initialize
|
||||||
|
llm.merge_chat_history = original_merge
|
||||||
|
llm.validate_provider_env = original_validate
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -54,7 +54,9 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
|
||||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
|
||||||
_llm = initialize_expert_llm("openai", "o1")
|
_llm = initialize_expert_llm("openai", "o1")
|
||||||
|
|
||||||
mock_openai.assert_called_once_with(api_key="test-key", model="o1", reasoning_effort="high")
|
mock_openai.assert_called_once_with(
|
||||||
|
api_key="test-key", model="o1", reasoning_effort="high"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
|
def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
|
||||||
|
|
@ -63,7 +65,10 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
|
||||||
_llm = initialize_expert_llm("openai", "gpt-4-preview")
|
_llm = initialize_expert_llm("openai", "gpt-4-preview")
|
||||||
|
|
||||||
mock_openai.assert_called_once_with(
|
mock_openai.assert_called_once_with(
|
||||||
api_key="test-key", model="gpt-4-preview", temperature=0, reasoning_effort="high"
|
api_key="test-key",
|
||||||
|
model="gpt-4-preview",
|
||||||
|
temperature=0,
|
||||||
|
reasoning_effort="high",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -348,7 +353,9 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
|
||||||
|
|
||||||
# Test LLM client creation with expert mode
|
# Test LLM client creation with expert mode
|
||||||
_llm = create_llm_client("openai", "o1", is_expert=True)
|
_llm = create_llm_client("openai", "o1", is_expert=True)
|
||||||
mock_openai.assert_called_with(api_key="expert-key", model="o1", reasoning_effort="high")
|
mock_openai.assert_called_with(
|
||||||
|
api_key="expert-key", model="o1", reasoning_effort="high"
|
||||||
|
)
|
||||||
|
|
||||||
# Test environment validation
|
# Test environment validation
|
||||||
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
|
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue