feat(issue): add LLM Tool Call Fallback Feature documentation to outline the new functionality for automatic fallback to alternative LLM models after consecutive failures
feat(ciayn_agent): implement fallback mechanism in CiaynAgent to handle tool call failures and switch to alternative models
feat(__main__): add command line arguments for fallback configuration in the main application
feat(llm): add validation for required environment variables for LLM providers and merge chat histories during fallback
fix(config): define default values for maximum tool failures in configuration
test(ciayn_agent): add unit tests for fallback logic and tool call execution with retries and error handling
test(llm): enhance tests for LLM initialization and environment variable validation
parent 00a455d586
commit 45b993cfd0

@@ -0,0 +1,115 @@

# LLM Tool Call Fallback Feature

## Overview

Add functionality to automatically fall back to alternative LLM models when a tool call experiences multiple consecutive failures.

## Background

Currently, when a tool call fails due to LLM-related errors (e.g., API timeouts, rate limits, context length issues), there is no automatic fallback mechanism. This can lead to interrupted workflows and a poor user experience.

## Relevant Files

- ra_aid/agents/ciayn_agent.py
- ra_aid/llm.py
- ra_aid/agent_utils.py
- ra_aid/__main__.py
- ra_aid/models_params.py

## Implementation Details

### Configuration

- Add new configuration value `max_tool_failures` (default: 3) that sets how many consecutive failures are allowed before triggering fallback
- Add new command line argument `--no-fallback-tool` to disable fallback behavior (enabled by default)
- Add new command line argument `--fallback-tool-models` to specify a comma-separated list of fallback tool models (default: "gpt-3.5-turbo,gpt-4"). This list defines the fallback model sequence used by forced tool calls (via `bind_tools`) when tool call failures occur
- Track failure count per tool call context
- Reset failure counter on successful tool call
- Store fallback model sequence per provider
- Validate that the required environment variables are set for a fallback model's provider before using it; if they are not set, fall back to the next model in the list (see the sketch below)
- Keep a default list of common models: try `claude-3-5-sonnet-20241022` first, with several alternative fallback models after it
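
A minimal sketch of how the environment check could gate fallback selection. The provider-to-variable mapping mirrors `validate_provider_env` below, but `FALLBACK_MODELS` and the helper name are illustrative assumptions, not the final implementation:

```python
import os
from typing import Optional, Set, Tuple

# Provider -> required environment variable (mirrors validate_provider_env).
REQUIRED_ENV_VARS = {
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
}

# Hypothetical default fallback sequence: (provider, model) pairs tried in order.
FALLBACK_MODELS = [
    ("anthropic", "claude-3-5-sonnet-20241022"),
    ("openai", "gpt-4"),
    ("openai", "gpt-3.5-turbo"),
]


def select_available_fallback(used: Set[str]) -> Optional[Tuple[str, str]]:
    """Return the first unused fallback model whose provider env var is set."""
    for provider, model in FALLBACK_MODELS:
        env_var = REQUIRED_ENV_VARS.get(provider)
        if model in used or not env_var:
            continue
        if os.getenv(env_var):
            return provider, model
    return None  # no usable fallback left; the caller re-raises the original error
```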

### Tool Call Wrapper

Create a new wrapper function to handle tool call execution with fallback logic:

```python
def execute_tool_with_fallback(tool_call_func, *args, **kwargs):
    failures = 0
    max_failures = get_config().max_tool_failures

    while failures < max_failures:
        try:
            return tool_call_func(*args, **kwargs)
        except LLMError as e:
            failures += 1
            if failures >= max_failures:
                # Use forced tool call via bind_tools with retry:
                llm_retry = llm_model.with_retry(stop_after_attempt=3)  # Try three times
                try_fallback_model(force=True, model=llm_retry)
                # Merge fallback model chat messages back into the original chat history.
                merge_fallback_chat_history()
                failures = 0  # Reset counter for new model
            else:
                raise
```

The prompt passed to `try_fallback_model` should include the last few failing tool calls; a rough sketch follows.
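
The sketch below shows one way that prompt could be assembled; the `failed_calls` structure and the helper name are assumptions used only for illustration:

```python
from typing import List, Tuple


def build_fallback_prompt(failed_calls: List[Tuple[str, str]], keep_last: int = 3) -> str:
    """Summarize the most recent failing tool calls for the fallback model.

    failed_calls holds (tool_call_code, error_message) pairs, oldest first.
    """
    recent = failed_calls[-keep_last:]
    lines = ["The following tool calls failed; produce a corrected call:"]
    for code, error in recent:
        lines.append(f"- call: {code}")
        lines.append(f"  error: {error}")
    return "\n".join(lines)
```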

### Model Fallback Sequence

Define fallback sequences for each provider based on model capabilities (a sketch follows this list):

1. Try same provider's smaller models
2. Try alternative providers' equivalent models
3. Raise the final error if all fallbacks fail
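
A minimal sketch of what per-provider fallback sequences could look like; the concrete model names and ordering here are illustrative assumptions:

```python
# Hypothetical per-provider fallback sequences, ordered from preferred to last resort.
# Step 1: smaller models from the same provider; step 2: equivalents from other providers.
FALLBACK_SEQUENCES = {
    "anthropic": [
        ("anthropic", "claude-3-5-haiku-20241022"),  # same provider, smaller model
        ("openai", "gpt-4o"),                        # alternative provider equivalent
        ("openai", "gpt-4o-mini"),
    ],
    "openai": [
        ("openai", "gpt-4o-mini"),
        ("anthropic", "claude-3-5-sonnet-20241022"),
        ("anthropic", "claude-3-5-haiku-20241022"),
    ],
}


def fallback_candidates(provider: str):
    """Yield (provider, model) pairs to try after the primary model keeps failing."""
    yield from FALLBACK_SEQUENCES.get(provider, [])
    # When the sequence is exhausted, the caller re-raises the original error (step 3).
```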

### Provider Strategy Updates

Update provider strategies to support fallback configuration (a possible tracking structure is sketched below):

- Add provider-specific fallback sequences
- Handle model capability validation during fallback
- Track successful/failed attempts
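
One possible shape for the per-provider tracking state; the class and field names are assumptions used only to make the idea concrete:

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class ProviderFallbackState:
    """Tracks fallback attempts for a single provider strategy (illustrative)."""

    provider: str
    successful_attempts: int = 0
    failed_attempts: int = 0
    used_fallbacks: Set[str] = field(default_factory=set)

    def record(self, model: str, ok: bool) -> None:
        # Record an attempt so the selection logic can skip models that keep failing.
        self.used_fallbacks.add(model)
        if ok:
            self.successful_attempts += 1
        else:
            self.failed_attempts += 1
```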

## Risks and Mitigations

1. **Performance Impact**
   - Risk: Multiple fallback attempts could increase latency
   - Mitigation: Set a reasonable `max_tool_failures` limit and timeouts

2. **Consistency**
   - Risk: Different models may give slightly different outputs
   - Mitigation: Validate output schema consistency across models

3. **Cost**
   - Risk: Fallback to more expensive models
   - Mitigation: Configure cost limits and preferred fallback sequences

4. **State Management**
   - Risk: Loss of context during fallbacks
   - Mitigation: Preserve conversation state and tool context

## Acceptance Criteria

1. Tool calls automatically attempt fallback models after N consecutive failures
2. `--no-fallback-tool` argument successfully disables fallback behavior
3. Fallback sequence respects provider and model capabilities
4. Original error is preserved if all fallbacks fail
5. Unit tests cover fallback scenarios and edge cases
6. README.md updated to reflect new behavior

## Testing

1. Unit tests for fallback wrapper
2. Integration tests with mock LLM failures
3. Provider strategy fallback tests
4. Command line argument handling
5. Error preservation and reporting
6. Performance impact measurement
7. Edge cases (e.g., partial failures, timeout handling)
8. State preservation during fallbacks

## Documentation Updates

1. Add fallback feature to main README
2. Document `--no-fallback-tool` in CLI help
3. Document provider-specific fallback sequences

## Future Considerations

1. Allow custom fallback sequences via configuration
2. Add monitoring and alerting for fallback frequency
3. Optimize fallback selection based on historical success rates
4. Cost-aware fallback routing

@@ -149,6 +149,17 @@ Examples:
        action="store_false",
        help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
    )
    parser.add_argument(
        "--no-fallback-tool",
        action="store_true",
        help="Disable fallback model switching.",
    )
    parser.add_argument(
        "--fallback-tool-models",
        type=str,
        default="gpt-3.5-turbo,gpt-4",
        help="Comma-separated list of fallback models to use in order.",
    )
    parser.add_argument(
        "--recursion-limit",
        type=int,

@@ -414,12 +425,35 @@ def main():
    )
    _global_memory["config"]["planner_model"] = args.planner_model or args.model

    _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
    _global_memory["config"]["fallback_tool_models"] = (
        [
            model.strip()
            for model in args.fallback_tool_models.split(",")
            if model.strip()
        ]
        if args.fallback_tool_models
        else []
    )

    # Store research config with fallback to base values
    _global_memory["config"]["research_provider"] = (
        args.research_provider or args.provider
    )
    _global_memory["config"]["research_model"] = args.research_model or args.model

    # Store fallback tool configuration
    _global_memory["config"]["no_fallback_tool"] = args.no_fallback_tool
    _global_memory["config"]["fallback_tool_models"] = (
        [
            model.strip()
            for model in args.fallback_tool_models.split(",")
            if model.strip()
        ]
        if args.fallback_tool_models
        else []
    )

    # Run research stage
    print_stage_header("Research Stage")

@@ -68,12 +68,29 @@ class CiaynAgent:
    - Memory management with configurable limits
    """

    class ToolCallFailure:
        """Tracks consecutive failures and fallback model usage for tool calls.

        Attributes:
            consecutive_failures (int): Count of consecutive failures for current model
            current_provider (Optional[str]): Current provider being used
            current_model (Optional[str]): Current model being used
            used_fallbacks (Set[str]): Set of fallback models already attempted
        """

        def __init__(self):
            self.consecutive_failures = 0
            self.current_provider = None
            self.current_model = None
            self.used_fallbacks = set()

    def __init__(
        self,
        model,
        tools: list,
        max_history_messages: int = 50,
        max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
        config: Optional[dict] = None,
    ):
        """Initialize the agent with a model and list of tools.

@@ -82,7 +99,17 @@ class CiaynAgent:
            tools: List of tools available to the agent
            max_history_messages: Maximum number of messages to keep in chat history
            max_tokens: Maximum number of tokens allowed in message history (None for no limit)
            config: Optional configuration dictionary for fallback settings
        """
        if config is None:
            config = {}
        self.config = config
        self.provider = config.get("provider", "openai")
        self.fallback_enabled = config.get("fallback_tool_enabled", True)
        fallback_models_str = config.get("fallback_tool_models", "gpt-3.5-turbo,gpt-4")
        self.fallback_tool_models = [
            m.strip() for m in fallback_models_str.split(",") if m.strip()
        ]
        self.model = model
        self.tools = tools
        self.max_history_messages = max_history_messages

@@ -90,6 +117,7 @@ class CiaynAgent:
        self.available_functions = []
        for t in tools:
            self.available_functions.append(get_function_info(t.func))
        self._tool_failure = CiaynAgent.ToolCallFailure()

    def _build_prompt(self, last_result: Optional[str] = None) -> str:
        """Build the prompt for the agent including available tools and context."""

@@ -221,23 +249,56 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
        return base_prompt

    def _execute_tool(self, code: str) -> str:
        """Execute a tool call and return its result."""
        globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
        """Execute a tool call with retry and fallback logic and return its result."""
        max_retries = 3
        retries = 0
        last_error = None
        while retries < max_retries:
            try:
                code = code.strip()
                if validate_function_call_pattern(code):
                    functions_list = "\n\n".join(self.available_functions)
                    code = _extract_tool_call(code, functions_list)
                globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
                result = eval(code, globals_dict)
                self._tool_failure.consecutive_failures = 0
                return result
            except Exception as e:
                self._handle_tool_failure(code, e)
                last_error = e
                retries += 1
        raise ToolExecutionError(
            f"Error executing code after {max_retries} attempts: {str(last_error)}"
        )

    def _handle_tool_failure(self, code: str, error: Exception) -> None:
        self._tool_failure.consecutive_failures += 1
        max_failures = self.config.get("max_tool_failures", 3)
        if (
            self.fallback_enabled
            and self._tool_failure.consecutive_failures >= max_failures
            and self.fallback_tool_models
        ):
            self._attempt_fallback(code)

    def _attempt_fallback(self, code: str) -> None:
        new_model = self.fallback_tool_models[0]
        failed_tool_call_name = code.split('(')[0].strip()
        logger.error(
            f"Tool call failed {self._tool_failure.consecutive_failures} times. Attempting fallback to model: {new_model} for tool: {failed_tool_call_name}"
        )
        try:
            code = code.strip()
            # code = code.replace("\n", " ")

            # if the eval fails, try to extract it via a model call
            if validate_function_call_pattern(code):
                functions_list = "\n\n".join(self.available_functions)
                code = _extract_tool_call(code, functions_list)

            result = eval(code.strip(), globals_dict)
            return result
        except Exception as e:
            error_msg = f"Error executing code: {str(e)}"
            raise ToolExecutionError(error_msg)
            from ra_aid.llm import initialize_llm, merge_chat_history, validate_provider_env
            if not validate_provider_env(self.provider):
                logger.error(f"Missing environment configuration for provider {self.provider}. Cannot fallback.")
            else:
                self.model = initialize_llm(self.provider, new_model)
                self.model.bind_tools(self.tools, tool_choice=failed_tool_call_name)
                self._tool_failure.used_fallbacks.add(new_model)
                merge_chat_history()  # Assuming merge_chat_history handles merging fallback history
                self._tool_failure.consecutive_failures = 0
        except Exception as switch_e:
            logger.error(f"Fallback model switching failed: {switch_e}")

    def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
        """Create an agent chunk in the format expected by print_agent_output."""

@@ -2,3 +2,5 @@

DEFAULT_RECURSION_LIMIT = 100
DEFAULT_MAX_TEST_CMD_RETRIES = 3
DEFAULT_MAX_TOOL_FAILURES = 3
MAX_TOOL_FAILURES = 3

@@ -1,8 +1,9 @@
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

@@ -47,9 +48,9 @@ def create_deepseek_client(
    return ChatDeepseekReasoner(
        api_key=api_key,
        base_url=base_url,
        temperature=0
        if is_expert
        else (temperature if temperature is not None else 1),
        temperature=(
            0 if is_expert else (temperature if temperature is not None else 1)
        ),
        model=model_name,
    )

@@ -72,9 +73,9 @@ def create_openrouter_client(
    return ChatDeepseekReasoner(
        api_key=api_key,
        base_url="https://openrouter.ai/api/v1",
        temperature=0
        if is_expert
        else (temperature if temperature is not None else 1),
        temperature=(
            0 if is_expert else (temperature if temperature is not None else 1)
        ),
        model=model_name,
    )

@@ -114,7 +115,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any
            "base_url": "https://api.deepseek.com",
        },
    }
    return configs.get(provider, {})
    config = configs.get(provider, {})
    if not config or not config.get("api_key"):
        raise ValueError(
            f"Missing required environment variable for provider: {provider}"
        )
    return config


def create_llm_client(

@@ -219,8 +225,41 @@ def initialize_llm(
    return create_llm_client(provider, model_name, temperature, is_expert=False)


def initialize_expert_llm(
    provider: str, model_name: str
) -> BaseChatModel:
def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)


def validate_provider_env(provider: str) -> bool:
    """Check if the required environment variables for a provider are set."""
    required_vars = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "openrouter": "OPENROUTER_API_KEY",
        "openai-compatible": "OPENAI_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }
    key = required_vars.get(provider.lower())
    if key:
        return bool(os.getenv(key))
    return False


def merge_chat_history(
    original_history: List[BaseMessage], fallback_history: List[BaseMessage]
) -> List[BaseMessage]:
    """Merge original and fallback chat histories while preserving order.

    Args:
        original_history: The original chat message history
        fallback_history: Additional messages from fallback attempts

    Returns:
        List[BaseMessage]: Combined message history preserving chronological order

    Note:
        The function appends fallback messages to maintain context for future
        interactions while preserving the original conversation flow.
    """
    return original_history + fallback_history

@@ -27,7 +27,6 @@ models_params = {
        "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True},
        "o1-preview": {"token_limit": 128000, "supports_temperature": False},
        "o1-mini": {"token_limit": 128000, "supports_temperature": False},
        "o1-preview": {"token_limit": 128000, "supports_temperature": False},
        "o1": {"token_limit": 200000, "supports_temperature": False},
        "o3-mini": {"token_limit": 200000, "supports_temperature": False},
    },

@@ -1,25 +1,18 @@
from ra_aid.tools import (
    ask_expert,
    ask_human,
    delete_key_facts,
    delete_key_snippets,
    deregister_related_files,
    emit_expert_context,
    emit_key_facts,
    emit_key_snippets,
    emit_plan,
    emit_related_files,
    emit_research_notes,
    fuzzy_find_project_files,
    list_directory_tree,
    monorepo_detected,
    plan_implementation_completed,
    read_file_tool,
    ripgrep_search,
    run_programming_task,
    run_shell_command,
    task_completed,
    ui_detected,
    web_search_tavily,
)
from ra_aid.tools.agent import (

@@ -29,7 +22,6 @@ from ra_aid.tools.agent import (
    request_task_implementation,
    request_web_research,
)
from ra_aid.tools.memory import one_shot_completed
from ra_aid.tools.write_file import write_file_tool

@@ -185,7 +185,11 @@ def ask_expert(question: str) -> str:

    query_parts.extend(["# Question", question])
    query_parts.extend(
        ["\n # Addidional Requirements", "**DO NOT OVERTHINK**", "**DO NOT OVERCOMPLICATE**"]
        [
            "\n # Addidional Requirements",
            "**DO NOT OVERTHINK**",
            "**DO NOT OVERCOMPLICATE**",
        ]
    )

    # Join all parts

@@ -1,11 +1,41 @@
import unittest
from unittest.mock import Mock

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
from ra_aid.exceptions import ToolExecutionError


# Dummy tool function for testing retry and fallback behavior
def dummy_tool():
    dummy_tool.attempt += 1
    if dummy_tool.attempt < 3:
        raise Exception("Simulated failure")
    return "dummy success"


dummy_tool.attempt = 0


class DummyTool:
    def __init__(self, func):
        self.func = func


class DummyModel:
    def invoke(self, messages):
        # Always return a code snippet that calls dummy_tool()
        class Response:
            content = "dummy_tool()"

        return Response()

    def bind_tools(self, tools, tool_choice):
        pass


# Fixtures from the source file
@pytest.fixture
def mock_model():
    """Create a mock language model."""

@@ -21,6 +51,7 @@ def agent(mock_model):
    return CiaynAgent(mock_model, tools, max_history_messages=3)


# Trimming test functions
def test_trim_chat_history_preserves_initial_messages(agent):
    """Test that initial messages are preserved during trimming."""
    initial_messages = [

@@ -33,9 +64,7 @@ def test_trim_chat_history_preserves_initial_messages(agent):
        HumanMessage(content="Chat 3"),
        AIMessage(content="Chat 4"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify initial messages are preserved
    assert result[:2] == initial_messages
    # Verify only last 3 chat messages are kept (due to max_history_messages=3)

@@ -47,9 +76,7 @@ def test_trim_chat_history_under_limit(agent):
    """Test trimming when chat history is under the maximum limit."""
    initial_messages = [HumanMessage(content="Initial")]
    chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify no trimming occurred
    assert len(result) == 3
    assert result == initial_messages + chat_history

@@ -65,9 +92,7 @@ def test_trim_chat_history_over_limit(agent):
        AIMessage(content="Chat 4"),
        HumanMessage(content="Chat 5"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify correct trimming
    assert len(result) == 4  # initial + max_history_messages
    assert result[0] == initial_messages[0]  # Initial message preserved

@@ -83,9 +108,7 @@ def test_trim_chat_history_empty_initial(agent):
        HumanMessage(content="Chat 3"),
        AIMessage(content="Chat 4"),
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify only last 3 messages are kept
    assert len(result) == 3
    assert result == chat_history[-3:]

@@ -98,9 +121,7 @@ def test_trim_chat_history_empty_chat(agent):
        AIMessage(content="Initial 2"),
    ]
    chat_history = []

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Verify initial messages are preserved and no trimming occurred
    assert result == initial_messages
    assert len(result) == 2

@@ -109,16 +130,13 @@

def test_trim_chat_history_token_limit():
    """Test trimming based on token limit."""
    agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)

    initial_messages = [HumanMessage(content="Initial")]  # ~2 tokens
    chat_history = [
        HumanMessage(content="A" * 40),  # ~10 tokens
        AIMessage(content="B" * 40),  # ~10 tokens
        HumanMessage(content="C" * 40),  # ~10 tokens
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Should keep initial message (~2 tokens) and last message (~10 tokens)
    assert len(result) == 2
    assert result[0] == initial_messages[0]
@ -128,16 +146,13 @@ def test_trim_chat_history_token_limit():
|
|||
def test_trim_chat_history_no_token_limit():
|
||||
"""Test trimming with no token limit set."""
|
||||
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
|
||||
|
||||
initial_messages = [HumanMessage(content="Initial")]
|
||||
chat_history = [
|
||||
HumanMessage(content="A" * 1000),
|
||||
AIMessage(content="B" * 1000),
|
||||
HumanMessage(content="C" * 1000),
|
||||
]
|
||||
|
||||
result = agent._trim_chat_history(initial_messages, chat_history)
|
||||
|
||||
# Should keep initial message and last 2 messages (max_history_messages=2)
|
||||
assert len(result) == 3
|
||||
assert result[0] == initial_messages[0]
|
||||
|
|

@@ -147,7 +162,6 @@ def test_trim_chat_history_no_token_limit():

def test_trim_chat_history_both_limits():
    """Test trimming with both message count and token limits."""
    agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)

    initial_messages = [HumanMessage(content="Init")]  # ~1 token
    chat_history = [
        HumanMessage(content="A" * 40),  # ~10 tokens

@@ -155,9 +169,7 @@ def test_trim_chat_history_both_limits():
        HumanMessage(content="C" * 40),  # ~10 tokens
        AIMessage(content="D" * 40),  # ~10 tokens
    ]

    result = agent._trim_chat_history(initial_messages, chat_history)

    # Should first apply message limit (keeping last 3)
    # Then token limit should further reduce to fit under 15 tokens
    assert len(result) == 2  # Initial message + 1 message under token limit

@@ -165,6 +177,33 @@ def test_trim_chat_history_both_limits():
    assert result[1] == chat_history[-1]


# Fallback tests
class TestCiaynAgentFallback(unittest.TestCase):
    def setUp(self):
        # Reset dummy_tool attempt counter before each test
        dummy_tool.attempt = 0
        self.dummy_tool = DummyTool(dummy_tool)
        self.model = DummyModel()
        # Create a CiaynAgent with the dummy tool
        self.agent = CiaynAgent(self.model, [self.dummy_tool])

    def test_retry_logic_with_failure_recovery(self):
        # Test that _execute_tool retries and eventually returns success
        result = self.agent._execute_tool("dummy_tool()")
        self.assertEqual(result, "dummy success")

    def test_switch_models_on_fallback(self):
        # Test fallback behavior by making dummy_tool always fail
        def always_fail():
            raise Exception("Persistent failure")

        always_fail_tool = DummyTool(always_fail)
        agent = CiaynAgent(self.model, [always_fail_tool])
        with self.assertRaises(ToolExecutionError):
            agent._execute_tool("always_fail()")


# Function call validation tests
class TestFunctionCallValidation:
    @pytest.mark.parametrize(
        "test_input",

@@ -221,3 +260,54 @@ class TestFunctionCallValidation:
    def test_multiline_responses(self, test_input):
        """Test function calls spanning multiple lines."""
        assert not validate_function_call_pattern(test_input)


class TestCiaynAgentNewMethods(unittest.TestCase):
    def setUp(self):
        # Create a dummy tool that always fails for testing fallback
        def always_fail():
            raise Exception("Failure for fallback test")

        self.always_fail_tool = DummyTool(always_fail)
        # Create a dummy model that does minimal work for fallback tests
        self.dummy_model = DummyModel()
        # Initialize CiaynAgent with configuration to trigger fallback quickly
        self.agent = CiaynAgent(
            self.dummy_model,
            [self.always_fail_tool],
            config={"max_tool_failures": 2, "fallback_tool_models": "dummy-fallback-model"}
        )

    def test_handle_tool_failure_increments_counter(self):
        initial_failures = self.agent._tool_failure.consecutive_failures
        self.agent._handle_tool_failure("dummy_call()", Exception("Test error"))
        self.assertEqual(self.agent._tool_failure.consecutive_failures, initial_failures + 1)

    def test_attempt_fallback_invokes_fallback_logic(self):
        # Monkey-patch initialize_llm, merge_chat_history, and validate_provider_env
        # to simulate fallback switching without external dependencies.
        def dummy_initialize_llm(provider, model_name, temperature=None):
            return self.dummy_model

        def dummy_merge_chat_history():
            return ["merged"]

        def dummy_validate_provider_env(provider):
            return True

        import ra_aid.llm as llm

        original_initialize = llm.initialize_llm
        original_merge = llm.merge_chat_history
        original_validate = llm.validate_provider_env
        llm.initialize_llm = dummy_initialize_llm
        llm.merge_chat_history = dummy_merge_chat_history
        llm.validate_provider_env = dummy_validate_provider_env

        # Set failure counter high enough to trigger fallback in _handle_tool_failure
        self.agent._tool_failure.consecutive_failures = 2
        # Call _attempt_fallback; it should reset the failure counter to 0 on success.
        self.agent._attempt_fallback("always_fail_tool()")
        self.assertEqual(self.agent._tool_failure.consecutive_failures, 0)
        # Restore original functions
        llm.initialize_llm = original_initialize
        llm.merge_chat_history = original_merge
        llm.validate_provider_env = original_validate


if __name__ == "__main__":
    unittest.main()

@@ -54,7 +54,9 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key")
    _llm = initialize_expert_llm("openai", "o1")

    mock_openai.assert_called_once_with(api_key="test-key", model="o1", reasoning_effort="high")
    mock_openai.assert_called_once_with(
        api_key="test-key", model="o1", reasoning_effort="high"
    )


def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):

@@ -63,7 +65,10 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
    _llm = initialize_expert_llm("openai", "gpt-4-preview")

    mock_openai.assert_called_once_with(
        api_key="test-key", model="gpt-4-preview", temperature=0, reasoning_effort="high"
        api_key="test-key",
        model="gpt-4-preview",
        temperature=0,
        reasoning_effort="high",
    )

@@ -348,7 +353,9 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):

    # Test LLM client creation with expert mode
    _llm = create_llm_client("openai", "o1", is_expert=True)
    mock_openai.assert_called_with(api_key="expert-key", model="o1", reasoning_effort="high")
    mock_openai.assert_called_with(
        api_key="expert-key", model="o1", reasoning_effort="high"
    )

    # Test environment validation
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")