Add Deepseek Provider Support and Custom Deepseek Reasoner Chat Model (#50)

* chore: Add DeepSeek provider environment variable support in env.py

* feat: Add DeepSeek provider validation strategy in provider_strategy.py

* feat: Add support for DEEPSEEK provider in initialize_llm function

* feat: Create ChatDeepseekReasoner for custom handling of R1 models

* feat: Configure custom OpenAI client for DeepSeek API integration

* chore: Remove unused json import from deepseek_chat.py

* refactor: Simplify invocation_params and update acompletion_with_retry method

* feat: Override _generate to ensure message alternation in DeepseekReasoner

* feat: Add support for ChatDeepseekReasoner in LLM initialization

* feat: Use custom ChatDeepseekReasoner for DeepSeek models in OpenRouter

* fix: Remove redundant condition for DeepSeek model initialization

* feat: Add DeepSeek support for expert model initialization in llm.py

* feat: Add DeepSeek model handling for OpenRouter in expert LLM initialization

* fix: Update model name checks for DeepSeek and OpenRouter providers

* refactor: Extract common logic for LLM initialization into reusable methods

* test: Add unit tests for DeepSeek and OpenRouter functionality

* test: Refactor tests to match updated LLM initialization and helpers

* fix: Import missing helper functions to resolve NameError in tests

* fix: Resolve NameError and improve environment variable fallback logic

* feat(readme): add DeepSeek API key requirements to documentation for better clarity on environment variables
feat(main.py): include DeepSeek as a supported provider in argument parsing for enhanced functionality
feat(deepseek_chat.py): implement ChatDeepseekReasoner class for handling DeepSeek reasoning models
feat(llm.py): add DeepSeek client creation logic to support DeepSeek models in the application
feat(models_tokens.py): define token limits for DeepSeek models to manage resource allocation
fix(provider_strategy.py): correct validation logic for DeepSeek environment variables to ensure proper configuration
chore(memory.py): refactor global memory structure for better readability and maintainability in the codebase
Ariel Frischer 2025-01-22 04:21:10 -08:00 committed by GitHub
parent 7a68de2d06
commit 686ab42f88
9 changed files with 683 additions and 312 deletions

README.md

@@ -293,6 +293,7 @@ RA.Aid supports multiple providers through environment variables:
- `ANTHROPIC_API_KEY`: Required for the default Anthropic provider
- `OPENAI_API_KEY`: Required for OpenAI provider
- `OPENROUTER_API_KEY`: Required for OpenRouter provider
- `DEEPSEEK_API_KEY`: Required for DeepSeek provider
- `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
- `GEMINI_API_KEY`: Required for Gemini provider
@@ -302,6 +303,7 @@ Expert Tool Environment Variables:
- `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider
- `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider
- `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider
- `EXPERT_DEEPSEEK_API_KEY`: API key for expert tool using DeepSeek provider
You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):
@@ -343,6 +345,15 @@ export GEMINI_API_KEY=your_api_key_here
ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411
```
4. **Using DeepSeek**
```bash
# Direct DeepSeek provider (requires DEEPSEEK_API_KEY)
ra-aid -m "Your task" --provider deepseek --model deepseek-reasoner
# DeepSeek via OpenRouter
ra-aid -m "Your task" --provider openrouter --model deepseek/deepseek-r1
```
4. **Configuring Expert Provider**
The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.
@@ -356,6 +367,10 @@ export GEMINI_API_KEY=your_api_key_here
export OPENROUTER_API_KEY=your_openrouter_api_key
ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411
# Use DeepSeek for expert tool
export DEEPSEEK_API_KEY=your_deepseek_api_key
ra-aid -m "Your task" --expert-provider deepseek --expert-model deepseek-reasoner
# Use default OpenAI for expert tool
export EXPERT_OPENAI_API_KEY=your_openai_api_key
ra-aid -m "Your task" --expert-provider openai --expert-model o1

main.py

@@ -39,6 +39,7 @@ def parse_arguments(args=None):
        "openai",
        "openrouter",
        "openai-compatible",
        "deepseek",
        "gemini",
    ]
    ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"

ra_aid/chat_models/deepseek_chat.py

@@ -0,0 +1,52 @@
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI
from typing import Any, List, Optional, Dict


# Docs: https://api-docs.deepseek.com/guides/reasoning_model
class ChatDeepseekReasoner(ChatOpenAI):
    """ChatDeepseekReasoner with custom overrides for R1/reasoner models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invocation_params(
        self, options: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        params = super().invocation_params(options, **kwargs)

        # Remove unsupported params for R1 models
        params.pop("temperature", None)
        params.pop("top_p", None)
        params.pop("presence_penalty", None)
        params.pop("frequency_penalty", None)

        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override _generate to ensure message alternation in accordance with the DeepSeek API."""
        processed = []
        prev_role = None

        for msg in messages:
            current_role = "user" if msg.type == "human" else "assistant"
            if prev_role == current_role:
                if processed:
                    processed[-1].content += f"\n\n{msg.content}"
            else:
                processed.append(msg)
            prev_role = current_role

        return super()._generate(
            processed, stop=stop, run_manager=run_manager, **kwargs
        )
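A minimal usage sketch (the API key value and the example messages are illustrative, not part of the diff); it constructs the client the same way `create_deepseek_client` in llm.py does and shows the kind of input the `_generate` override normalizes:

```python
from langchain_core.messages import HumanMessage

from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner

# Constructed the same way create_deepseek_client() in llm.py builds it.
llm = ChatDeepseekReasoner(
    api_key="sk-example",  # hypothetical key, for illustration only
    base_url="https://api.deepseek.com",
    temperature=1,
    model="deepseek-reasoner",
)

# Two consecutive human messages would violate the API's requirement that
# user/assistant roles alternate; the _generate override merges them into a
# single user turn before the request is sent.
messages = [
    HumanMessage(content="Summarize the change."),
    HumanMessage(content="Keep it to one sentence."),
]
response = llm.invoke(messages)
print(response.content)
```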

ra_aid/env.py

@@ -50,6 +50,9 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
        'gemini': {
            'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY',
            'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL'
        },
        'deepseek': {
            'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY'
        }
    }
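A rough sketch of the fallback this new mapping enables (values are illustrative, and the loop is a simplification of `copy_base_to_expert_vars`, not its exact body): when the expert variant is unset, the base DeepSeek key is copied over so the expert tool can reuse it.

```python
import os

os.environ["DEEPSEEK_API_KEY"] = "sk-example"  # illustrative value
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

# The new mapping entry: base variable -> expert variable
mapping = {"DEEPSEEK_API_KEY": "EXPERT_DEEPSEEK_API_KEY"}
for base_var, expert_var in mapping.items():
    if os.environ.get(base_var) and not os.environ.get(expert_var):
        os.environ[expert_var] = os.environ[base_var]

assert os.environ["EXPERT_DEEPSEEK_API_KEY"] == "sk-example"
```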

ra_aid/llm.py

@@ -1,107 +1,184 @@
import os
from typing import Optional, Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner


def get_env_var(name: str, expert: bool = False) -> Optional[str]:
    """Get environment variable with optional expert prefix and fallback."""
    prefix = "EXPERT_" if expert else ""
    value = os.getenv(f"{prefix}{name}")

    # If expert mode and no expert value, fall back to base value
    if expert and not value:
        value = os.getenv(name)

    return value


def create_deepseek_client(
    model_name: str,
    api_key: str,
    base_url: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create DeepSeek client with appropriate configuration."""
    if model_name.lower() == "deepseek-reasoner":
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url=base_url,
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url=base_url,
        temperature=0 if is_expert else (temperature if temperature is not None else 1),
        model=model_name,
    )


def create_openrouter_client(
    model_name: str,
    api_key: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create OpenRouter client with appropriate configuration."""
    if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url="https://openrouter.ai/api/v1",
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url="https://openrouter.ai/api/v1",
        model=model_name,
        **({"temperature": temperature} if temperature is not None else {}),
    )


def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]:
    """Get provider-specific configuration."""
    configs = {
        "openai": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": None,
        },
        "anthropic": {
            "api_key": get_env_var("ANTHROPIC_API_KEY", is_expert),
            "base_url": None,
        },
        "openrouter": {
            "api_key": get_env_var("OPENROUTER_API_KEY", is_expert),
            "base_url": "https://openrouter.ai/api/v1",
        },
        "openai-compatible": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": get_env_var("OPENAI_API_BASE", is_expert),
        },
        "gemini": {
            "api_key": get_env_var("GEMINI_API_KEY", is_expert),
            "base_url": None,
        },
        "deepseek": {
            "api_key": get_env_var("DEEPSEEK_API_KEY", is_expert),
            "base_url": "https://api.deepseek.com",
        },
    }
    return configs.get(provider, {})


def create_llm_client(
    provider: str,
    model_name: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create a language model client with appropriate configuration.

    Args:
        provider: The LLM provider to use
        model_name: Name of the model to use
        temperature: Optional temperature setting (0.0-2.0)
        is_expert: Whether this is an expert model (uses deterministic output)

    Returns:
        Configured language model client
    """
    config = get_provider_config(provider, is_expert)
    if not config:
        raise ValueError(f"Unsupported provider: {provider}")

    # Handle temperature for expert mode
    if is_expert:
        temperature = 0

    if provider == "deepseek":
        return create_deepseek_client(
            model_name=model_name,
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openrouter":
        return create_openrouter_client(
            model_name=model_name,
            api_key=config["api_key"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openai":
        return ChatOpenAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=config["api_key"],
            model_name=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature if temperature is not None else 0.3,
            model=model_name,
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def initialize_llm(
    provider: str, model_name: str, temperature: float | None = None
) -> BaseChatModel:
    """Initialize a language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature, is_expert=False)


def initialize_expert_llm(
    provider: str = "openai", model_name: str = "o1"
) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)
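A short usage sketch of the refactored entry points (the key value is a placeholder; both calls assume the environment variables described in the README are set):

```python
import os

from ra_aid.llm import initialize_llm, initialize_expert_llm

os.environ["DEEPSEEK_API_KEY"] = "sk-example"  # placeholder key

# Agent model: routed through create_deepseek_client, which returns
# ChatDeepseekReasoner for the reasoner model (temperature defaults to 1).
agent_llm = initialize_llm("deepseek", "deepseek-reasoner")

# Expert model: same path, but is_expert=True forces temperature=0 and looks up
# EXPERT_DEEPSEEK_API_KEY first, falling back to DEEPSEEK_API_KEY.
expert_llm = initialize_expert_llm("deepseek", "deepseek-reasoner")
```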

models_tokens.py

@@ -241,6 +241,10 @@ models_tokens = {
    "deepseek": {
        "deepseek-chat": 28672,
        "deepseek-coder": 16384,
        "deepseek-reasoner": 65536,
    },
    "openrouter": {
        "deepseek/deepseek-r1": 65536,
    },
    "ernie": {
        "ernie-bot-turbo": 4096,

provider_strategy.py

@@ -239,6 +239,31 @@ class GeminiStrategy(ProviderStrategy):
        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class DeepSeekStrategy(ProviderStrategy):
    """DeepSeek provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate DeepSeek environment variables."""
        missing = []

        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek':
            key = os.environ.get('EXPERT_DEEPSEEK_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('DEEPSEEK_API_KEY')
                if base_key:
                    os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key
                    key = base_key
                if not key:
                    missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set')
        else:
            key = os.environ.get('DEEPSEEK_API_KEY')
            if not key:
                missing.append('DEEPSEEK_API_KEY environment variable is not set')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class OllamaStrategy(ProviderStrategy):
    """Ollama provider validation strategy."""

@@ -272,7 +297,8 @@ class ProviderFactory:
            'anthropic': AnthropicStrategy(),
            'openrouter': OpenRouterStrategy(),
            'gemini': GeminiStrategy(),
            'ollama': OllamaStrategy(),
            'deepseek': DeepSeekStrategy()
        }
        strategy = strategies.get(provider)
        return strategy
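A sketch of how the new strategy validates the environment (the `ra_aid.provider_strategy` import path is assumed; `args=None` exercises the non-expert branch):

```python
import os

from ra_aid.provider_strategy import DeepSeekStrategy  # import path assumed

os.environ.pop("DEEPSEEK_API_KEY", None)
result = DeepSeekStrategy().validate(args=None)
assert not result.valid
assert "DEEPSEEK_API_KEY environment variable is not set" in result.missing_vars

os.environ["DEEPSEEK_API_KEY"] = "sk-example"  # illustrative value
assert DeepSeekStrategy().validate(args=None).valid
```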

memory.py

@@ -1,123 +1,147 @@
import os

from typing import Dict, List, Any, Union, Optional, Set
from typing_extensions import TypedDict

from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from langchain_core.tools import tool


class WorkLogEntry(TypedDict):
    timestamp: str
    event: str


class SnippetInfo(TypedDict):
    """Type definition for source code snippet information"""

    filepath: str
    line_number: int
    snippet: str
    description: Optional[str]


console = Console()

# Global memory store
_global_memory: Dict[
    str,
    Union[
        List[Any],
        Dict[int, str],
        Dict[int, SnippetInfo],
        int,
        Set[str],
        bool,
        str,
        int,
        List[WorkLogEntry],
    ],
] = {
    "research_notes": [],
    "plans": [],
    "tasks": {},  # Dict[int, str] - ID to task mapping
    "task_completed": False,  # Flag indicating if task is complete
    "completion_message": "",  # Message explaining completion
    "task_id_counter": 1,  # Counter for generating unique task IDs
    "key_facts": {},  # Dict[int, str] - ID to fact mapping
    "key_fact_id_counter": 1,  # Counter for generating unique fact IDs
    "key_snippets": {},  # Dict[int, SnippetInfo] - ID to snippet mapping
    "key_snippet_id_counter": 1,  # Counter for generating unique snippet IDs
    "implementation_requested": False,
    "related_files": {},  # Dict[int, str] - ID to filepath mapping
    "related_file_id_counter": 1,  # Counter for generating unique file IDs
    "plan_completed": False,
    "agent_depth": 0,
    "work_log": [],  # List[WorkLogEntry] - Timestamped work events
}


@tool("emit_research_notes")
def emit_research_notes(notes: str) -> str:
    """Store research notes in global memory.

    Args:
        notes: REQUIRED The research notes to store

    Returns:
        The stored notes
    """
    _global_memory["research_notes"].append(notes)
    console.print(Panel(Markdown(notes), title="🔍 Research Notes"))
    return notes


@tool("emit_plan")
def emit_plan(plan: str) -> str:
    """Store a plan step in global memory.

    Args:
        plan: The plan step to store (markdown format; be clear, complete, use newlines, and use as many tokens as you need)

    Returns:
        The stored plan
    """
    _global_memory["plans"].append(plan)
    console.print(Panel(Markdown(plan), title="📋 Plan"))
    log_work_event(f"Added plan step:\n\n{plan}")
    return plan


@tool("emit_task")
def emit_task(task: str) -> str:
    """Store a task in global memory.

    Args:
        task: The task to store

    Returns:
        String confirming task storage with ID number
    """
    # Get and increment task ID
    task_id = _global_memory["task_id_counter"]
    _global_memory["task_id_counter"] += 1

    # Store task with ID
    _global_memory["tasks"][task_id] = task

    console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}"))
    log_work_event(f"Task #{task_id} added:\n\n{task}")
    return f"Task #{task_id} stored."


@tool("emit_key_facts")
def emit_key_facts(facts: List[str]) -> str:
    """Store multiple key facts about the project or current task in global memory.

    Args:
        facts: List of key facts to store

    Returns:
        List of stored fact confirmation messages
    """
    results = []
    for fact in facts:
        # Get and increment fact ID
        fact_id = _global_memory["key_fact_id_counter"]
        _global_memory["key_fact_id_counter"] += 1

        # Store fact with ID
        _global_memory["key_facts"][fact_id] = fact

        # Display panel with ID
        console.print(
            Panel(
                Markdown(fact),
                title=f"💡 Key Fact #{fact_id}",
                border_style="bright_cyan",
            )
        )

        # Add result message
        results.append(f"Stored fact #{fact_id}: {fact}")

    log_work_event(f"Stored {len(facts)} key facts.")
    return "Facts stored."
@@ -125,50 +149,54 @@ def emit_key_facts(facts: List[str]) -> str:
def delete_key_facts(fact_ids: List[int]) -> str:
    """Delete multiple key facts from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        fact_ids: List of fact IDs to delete

    Returns:
        List of success messages for deleted facts
    """
    results = []
    for fact_id in fact_ids:
        if fact_id in _global_memory["key_facts"]:
            # Delete the fact
            deleted_fact = _global_memory["key_facts"].pop(fact_id)
            success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
            console.print(
                Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted facts {fact_ids}.")
    return "Facts deleted."


@tool("delete_tasks")
def delete_tasks(task_ids: List[int]) -> str:
    """Delete multiple tasks from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        task_ids: List of task IDs to delete

    Returns:
        Confirmation message
    """
    results = []
    for task_id in task_ids:
        if task_id in _global_memory["tasks"]:
            # Delete the task
            deleted_task = _global_memory["tasks"].pop(task_id)
            success_msg = f"Successfully deleted task #{task_id}: {deleted_task}"
            console.print(
                Panel(Markdown(success_msg), title="Task Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted tasks {task_ids}.")
    return "Tasks deleted."


@tool("request_implementation")
def request_implementation() -> str:
    """Request that implementation proceed after research/planning.
@@ -178,46 +206,47 @@ def request_implementation() -> str:
    Do you need to request research subtasks first?
    Have you run relevant unit tests, if they exist, to get a baseline (this can be a subtask)?
    Do you need to crawl deeper to find all related files and symbols?

    Returns:
        Empty string
    """
    _global_memory["implementation_requested"] = True
    console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0))
    log_work_event("Implementation requested.")
    return ""


@tool("emit_key_snippets")
def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
    """Store multiple key source code snippets in global memory.
    Automatically adds the filepaths of the snippets to related files.

    This is for **existing**, or **just-written** files, not for things to be created in the future.

    Args:
        snippets: REQUIRED List of snippet information dictionaries containing:
            - filepath: Path to the source file
            - line_number: Line number where the snippet starts
            - snippet: The source code snippet text
            - description: Optional description of the significance

    Returns:
        List of stored snippet confirmation messages
    """
    # First collect unique filepaths to add as related files
    emit_related_files.invoke(
        {"files": [snippet_info["filepath"] for snippet_info in snippets]}
    )

    results = []
    for snippet_info in snippets:
        # Get and increment snippet ID
        snippet_id = _global_memory["key_snippet_id_counter"]
        _global_memory["key_snippet_id_counter"] += 1

        # Store snippet info
        _global_memory["key_snippets"][snippet_id] = snippet_info

        # Format display text as markdown
        display_text = [
            f"**Source Location**:",
@@ -226,153 +255,170 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
            "",  # Empty line before code block
            "**Code**:",
            "```python",
            snippet_info["snippet"].rstrip(),  # Remove trailing whitespace
            "```",
        ]
        if snippet_info["description"]:
            display_text.extend(["", "**Description**:", snippet_info["description"]])

        # Display panel
        console.print(
            Panel(
                Markdown("\n".join(display_text)),
                title=f"📝 Key Snippet #{snippet_id}",
                border_style="bright_cyan",
            )
        )
        results.append(f"Stored snippet #{snippet_id}")

    log_work_event(f"Stored {len(snippets)} code snippets.")
    return "Snippets stored."


@tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str:
    """Delete multiple key snippets from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        snippet_ids: List of snippet IDs to delete

    Returns:
        List of success messages for deleted snippets
    """
    results = []
    for snippet_id in snippet_ids:
        if snippet_id in _global_memory["key_snippets"]:
            # Delete the snippet
            deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
            success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
            console.print(
                Panel(
                    Markdown(success_msg), title="Snippet Deleted", border_style="green"
                )
            )
            results.append(success_msg)

    log_work_event(f"Deleted snippets {snippet_ids}.")
    return "Snippets deleted."


@tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str:
    """Swap the order of two tasks in global memory by their IDs.

    Args:
        id1: First task ID
        id2: Second task ID

    Returns:
        Success or error message depending on outcome
    """
    # Validate IDs are different
    if id1 == id2:
        return "Cannot swap task with itself"

    # Validate both IDs exist
    if id1 not in _global_memory["tasks"] or id2 not in _global_memory["tasks"]:
        return "Invalid task ID(s)"

    # Swap the tasks
    _global_memory["tasks"][id1], _global_memory["tasks"][id2] = (
        _global_memory["tasks"][id2],
        _global_memory["tasks"][id1],
    )

    # Display what was swapped
    console.print(
        Panel(
            Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"),
            title="🔄 Tasks Reordered",
            border_style="green",
        )
    )

    return "Tasks swapped."


@tool("one_shot_completed")
def one_shot_completed(message: str) -> str:
    """Signal that a one-shot task has been completed and execution should stop.
    Only call this if you have already **fully** completed the original request.

    Args:
        message: Completion message to display
    """
    if _global_memory.get("implementation_requested", False):
        return "Cannot complete in one shot - implementation was requested"

    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    log_work_event(f"Task completed\n\n{message}")
    return "Completion noted."


@tool("task_completed")
def task_completed(message: str) -> str:
    """Mark the current task as completed with a completion message.

    Args:
        message: Message explaining how/why the task is complete

    Returns:
        The completion message
    """
    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    return "Completion noted."


@tool("plan_implementation_completed")
def plan_implementation_completed(message: str) -> str:
    """Mark the entire implementation plan as completed.

    Args:
        message: Message explaining how the implementation plan was completed

    Returns:
        Confirmation message
    """
    _global_memory["plan_completed"] = True
    _global_memory["completion_message"] = message
    _global_memory["tasks"].clear()  # Clear task list when plan is completed
    _global_memory["task_id_counter"] = 1
    console.print(Panel(Markdown(message), title="✅ Plan Executed"))
    log_work_event(f"Plan execution completed:\n\n{message}")
    return "Plan completion noted and task list cleared."


def get_related_files() -> List[str]:
    """Get the current list of related files.

    Returns:
        List of formatted strings in the format 'ID#X path/to/file.py'
    """
    files = _global_memory["related_files"]
    return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]


@tool("emit_related_files")
def emit_related_files(files: List[str]) -> str:
    """Store multiple related files that tools should work with.

    Args:
        files: List of file paths to add

    Returns:
        Formatted string containing file IDs and paths for all processed files
    """
    results = []
    added_files = []
    invalid_paths = []

    # Process files
    for file in files:
        # First check if path exists
@@ -380,111 +426,115 @@ def emit_related_files(files: List[str]) -> str:
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' does not exist")
            continue

        # Then check if it's a directory
        if os.path.isdir(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' is a directory, not a file")
            continue

        # Finally validate it's a regular file
        if not os.path.isfile(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' exists but is not a regular file")
            continue

        # Check if file path already exists in values
        existing_id = None
        for fid, fpath in _global_memory["related_files"].items():
            if fpath == file:
                existing_id = fid
                break

        if existing_id is not None:
            # File exists, use existing ID
            results.append(f"File ID #{existing_id}: {file}")
        else:
            # New file, assign new ID
            file_id = _global_memory["related_file_id_counter"]
            _global_memory["related_file_id_counter"] += 1

            # Store file with ID
            _global_memory["related_files"][file_id] = file
            added_files.append((file_id, file))
            results.append(f"File ID #{file_id}: {file}")

    # Rich output - single consolidated panel
    if added_files:
        files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
        md_content = f"**Files Noted:**\n{files_added_md}"
        console.print(
            Panel(
                Markdown(md_content),
                title="📁 Related Files Noted",
                border_style="green",
            )
        )

    return "\n".join(results)


def log_work_event(event: str) -> str:
    """Add timestamped entry to work log.

    Internal function used to track major events during agent execution.
    Each entry is stored with an ISO format timestamp.

    Args:
        event: Description of the event to log

    Returns:
        Confirmation message

    Note:
        Entries can be retrieved with get_work_log() as markdown formatted text.
    """
    from datetime import datetime

    entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
    _global_memory["work_log"].append(entry)
    return f"Event logged: {event}"


def get_work_log() -> str:
    """Return formatted markdown of work log entries.

    Returns:
        Markdown formatted text with timestamps as headings and events as content,
        or 'No work log entries' if the log is empty.

    Example:
        ## 2024-12-23T11:39:10
        Task #1 added: Create login form
    """
    if not _global_memory["work_log"]:
        return "No work log entries"

    entries = []
    for entry in _global_memory["work_log"]:
        entries.extend(
            [
                f"## {entry['timestamp']}",
                "",
                entry["event"],
                "",  # Blank line between entries
            ]
        )

    return "\n".join(entries).rstrip()  # Remove trailing newline


def reset_work_log() -> str:
    """Clear the work log.

    Returns:
        Confirmation message

    Note:
        This permanently removes all work log entries. The operation cannot be undone.
    """
    _global_memory["work_log"].clear()
    return "Work log cleared"
@@ -492,37 +542,44 @@ def reset_work_log() -> str:
def deregister_related_files(file_ids: List[int]) -> str:
    """Delete multiple related files from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        file_ids: List of file IDs to delete

    Returns:
        Success message string
    """
    results = []
    for file_id in file_ids:
        if file_id in _global_memory["related_files"]:
            # Delete the file reference
            deleted_file = _global_memory["related_files"].pop(file_id)
            success_msg = (
                f"Successfully removed related file #{file_id}: {deleted_file}"
            )
            console.print(
                Panel(
                    Markdown(success_msg),
                    title="File Reference Removed",
                    border_style="green",
                )
            )
            results.append(success_msg)

    return "File references removed."


def get_memory_value(key: str) -> str:
    """Get a value from global memory.

    Different memory types return different formats:
    - key_facts: Returns numbered list of facts in format '#ID: fact'
    - key_snippets: Returns formatted snippets with file path, line number and content
    - All other types: Returns newline-separated list of values

    Args:
        key: The key to get from memory

    Returns:
        String representation of the memory values:
        - For key_facts: '#ID: fact' format, one per line
@@ -530,23 +587,25 @@ def get_memory_value(key: str) -> str:
        - For other types: One value per line
    """
    values = _global_memory.get(key, [])

    if key == "key_facts":
        # For empty dict, return empty string
        if not values:
            return ""
        # Sort by ID for consistent output and format as markdown sections
        facts = []
        for k, v in sorted(values.items()):
            facts.extend(
                [
                    f"## 🔑 Key Fact #{k}",
                    "",  # Empty line for better markdown spacing
                    v,
                    "",  # Empty line between facts
                ]
            )
        return "\n".join(facts).rstrip()  # Remove trailing newline

    if key == "key_snippets":
        if not values:
            return ""
        # Format each snippet with file info and content using markdown
@@ -555,26 +614,25 @@ def get_memory_value(key: str) -> str:
            snippet_text = [
                f"## 📝 Code Snippet #{k}",
                "",  # Empty line for better markdown spacing
                "**Source Location**:",
                f"- File: `{v['filepath']}`",
                f"- Line: `{v['line_number']}`",
                "",  # Empty line before code block
                "**Code**:",
                "```python",
                v["snippet"].rstrip(),  # Remove trailing whitespace
                "```",
            ]
            if v["description"]:
                # Add empty line and description
                snippet_text.extend(["", "**Description**:", v["description"]])
            snippets.append("\n".join(snippet_text))
        return "\n\n".join(snippets)

    if key == "work_log":
        if not values:
            return ""
        entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
        return "\n\n".join(entries)

    # For other types (lists), join with newlines
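A small sketch of how these memory tools round-trip through the global store (the `ra_aid.tools.memory` import path is an assumption; the fact text is illustrative):

```python
from ra_aid.tools.memory import emit_key_facts, get_memory_value  # path assumed

# The emit_* helpers are LangChain @tool objects, so they are called via .invoke()
emit_key_facts.invoke(
    {"facts": ["DeepSeek R1 requires alternating user/assistant messages."]}
)

# Facts come back as markdown sections keyed by their IDs
print(get_memory_value("key_facts"))
```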

LLM tests

@@ -9,7 +9,13 @@ from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment
from ra_aid.llm import (
    initialize_llm,
    initialize_expert_llm,
    get_env_var,
    get_provider_config,
    create_llm_client
)


@pytest.fixture
def clean_env(monkeypatch):
@@ -39,7 +45,8 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="o1",
        temperature=0
    )


def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
@@ -49,7 +56,8 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="gpt-4-preview",
        temperature=0
    )


def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
@@ -59,7 +67,8 @@ def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
    mock_gemini.assert_called_once_with(
        api_key="test-key",
        model="gemini-2.0-flash-thinking-exp-1219",
        temperature=0
    )


def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
@@ -69,7 +78,8 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
    mock_anthropic.assert_called_once_with(
        api_key="test-key",
        model_name="claude-3",
        temperature=0
    )


def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
@@ -80,7 +90,8 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="models/mistral-large",
        temperature=0
    )


def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch):
@@ -92,7 +103,8 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="http://test-url",
        model="local-model",
        temperature=0
    )


def test_initialize_expert_unsupported_provider(clean_env):
@@ -311,45 +323,49 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m
        model="gemini-2.0-flash-thinking-exp-1219"
    )


@dataclass
class Args:
    """Test arguments class."""
    provider: str
    expert_provider: str
    model: str = None
    expert_model: str = None


def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
    """Test environment variable precedence and fallback."""
    # Test get_env_var helper with fallback
    monkeypatch.setenv("TEST_KEY", "base-value")
    monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value")
    assert get_env_var("TEST_KEY") == "base-value"
    assert get_env_var("TEST_KEY", expert=True) == "expert-value"

    # Test fallback when expert value not set
    monkeypatch.delenv("EXPERT_TEST_KEY", raising=False)
    assert get_env_var("TEST_KEY", expert=True) == "base-value"

    # Test provider config
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key")
    config = get_provider_config("openai", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test LLM client creation with expert mode
    llm = create_llm_client("openai", "o1", is_expert=True)
    mock_openai.assert_called_with(
        api_key="expert-key",
        model="o1",
        temperature=0
    )

    # Test environment validation
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("TAVILY_API_KEY", raising=False)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307")

    # Change base provider to avoid validation error
    args = Args(provider="anthropic", expert_provider="openai")
    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
    assert not expert_enabled
    assert expert_missing
@@ -372,3 +388,122 @@ def mock_gemini():
    with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock:
        mock.return_value = Mock(spec=ChatGoogleGenerativeAI)
        yield mock


@pytest.fixture
def mock_deepseek_reasoner():
    """Mock ChatDeepseekReasoner for testing DeepSeek provider initialization."""
    with patch('ra_aid.llm.ChatDeepseekReasoner') as mock:
        mock.return_value = Mock()
        yield mock


def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek provider initialization with different models."""
    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-chat"
    )


def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert DeepSeek provider initialization."""
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-chat"
    )


def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model
    model = initialize_llm("openrouter", "deepseek/deepseek-r1")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=1,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = initialize_llm("openrouter", "mistral/mistral-large")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large"
    )


def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model via create_llm_client
    model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=0,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True)
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large",
        temperature=0
    )


def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek environment variable fallback behavior."""
    # Test environment variable helper with fallback
    monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key")
    assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key"

    # Test provider config with fallback
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "base-key"
    assert config["base_url"] == "https://api.deepseek.com"

    # Test with expert key
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key")
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test client creation with expert key
    model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="expert-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )