Add DeepSeek Provider Support and Custom DeepSeek Reasoner Chat Model (#50)

* chore: Add DeepSeek provider environment variable support in env.py

* feat: Add DeepSeek provider validation strategy in provider_strategy.py

* feat: Add support for DEEPSEEK provider in initialize_llm function

* feat: Create ChatDeepseekReasoner for custom handling of R1 models

* feat: Configure custom OpenAI client for DeepSeek API integration

* chore: Remove unused json import from deepseek_chat.py

* refactor: Simplify invocation_params and update acompletion_with_retry method

* feat: Override _generate to ensure message alternation in DeepseekReasoner

* feat: Add support for ChatDeepseekReasoner in LLM initialization

* feat: Use custom ChatDeepseekReasoner for DeepSeek models in OpenRouter

* fix: Remove redundant condition for DeepSeek model initialization

* feat: Add DeepSeek support for expert model initialization in llm.py

* feat: Add DeepSeek model handling for OpenRouter in expert LLM initialization

* fix: Update model name checks for DeepSeek and OpenRouter providers

* refactor: Extract common logic for LLM initialization into reusable methods

* test: Add unit tests for DeepSeek and OpenRouter functionality

* test: Refactor tests to match updated LLM initialization and helpers

* fix: Import missing helper functions to resolve NameError in tests

* fix: Resolve NameError and improve environment variable fallback logic

* feat(readme): add DeepSeek API key requirements to documentation for better clarity on environment variables
feat(main.py): include DeepSeek as a supported provider in argument parsing for enhanced functionality
feat(deepseek_chat.py): implement ChatDeepseekReasoner class for handling DeepSeek reasoning models
feat(llm.py): add DeepSeek client creation logic to support DeepSeek models in the application
feat(models_tokens.py): define token limits for DeepSeek models to manage resource allocation
fix(provider_strategy.py): correct validation logic for DeepSeek environment variables to ensure proper configuration
chore(memory.py): refactor global memory structure for better readability and maintainability in the codebase
Ariel Frischer 2025-01-22 04:21:10 -08:00 committed by GitHub
parent 7a68de2d06
commit 686ab42f88
9 changed files with 683 additions and 312 deletions
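
Before the per-file diffs, here is a minimal usage sketch (not part of the commit) showing how the new DeepSeek paths are expected to be driven end to end, assuming the `ra_aid.llm` entry points added below and placeholder API keys:

```python
# Minimal sketch only: exercises the DeepSeek paths added in this commit.
# The keys below are placeholders; real keys are needed for actual requests.
import os

from ra_aid.llm import initialize_llm, initialize_expert_llm

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"
os.environ["OPENROUTER_API_KEY"] = "sk-placeholder"
os.environ["EXPERT_DEEPSEEK_API_KEY"] = "sk-placeholder"

# deepseek-reasoner routes to the custom ChatDeepseekReasoner client.
agent_llm = initialize_llm("deepseek", "deepseek-reasoner")

# DeepSeek R1 served through OpenRouter also routes to ChatDeepseekReasoner.
router_llm = initialize_llm("openrouter", "deepseek/deepseek-r1")

# Expert initialization reuses the same helpers and forces temperature=0.
expert_llm = initialize_expert_llm("deepseek", "deepseek-reasoner")
```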

View File

@@ -293,6 +293,7 @@ RA.Aid supports multiple providers through environment variables:
- `ANTHROPIC_API_KEY`: Required for the default Anthropic provider
- `OPENAI_API_KEY`: Required for OpenAI provider
- `OPENROUTER_API_KEY`: Required for OpenRouter provider
- `DEEPSEEK_API_KEY`: Required for DeepSeek provider
- `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
- `GEMINI_API_KEY`: Required for Gemini provider
@@ -302,6 +303,7 @@ Expert Tool Environment Variables:
- `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider
- `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider
- `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider
- `EXPERT_DEEPSEEK_API_KEY`: API key for expert tool using DeepSeek provider
You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):
@@ -343,6 +345,15 @@ export GEMINI_API_KEY=your_api_key_here
ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411
```
4. **Using DeepSeek**
```bash
# Direct DeepSeek provider (requires DEEPSEEK_API_KEY)
ra-aid -m "Your task" --provider deepseek --model deepseek-reasoner
# DeepSeek via OpenRouter
ra-aid -m "Your task" --provider openrouter --model deepseek/deepseek-r1
```
5. **Configuring Expert Provider**
The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.
@@ -356,6 +367,10 @@ export GEMINI_API_KEY=your_api_key_here
export OPENROUTER_API_KEY=your_openrouter_api_key
ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411
# Use DeepSeek for expert tool
export DEEPSEEK_API_KEY=your_deepseek_api_key
ra-aid -m "Your task" --expert-provider deepseek --expert-model deepseek-reasoner
# Use default OpenAI for expert tool
export EXPERT_OPENAI_API_KEY=your_openai_api_key
ra-aid -m "Your task" --expert-provider openai --expert-model o1

View File

@@ -39,6 +39,7 @@ def parse_arguments(args=None):
"openai",
"openrouter",
"openai-compatible",
"deepseek",
"gemini",
]
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"

View File

@@ -0,0 +1,52 @@
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI
from typing import Any, List, Optional, Dict
# Docs: https://api-docs.deepseek.com/guides/reasoning_model
class ChatDeepseekReasoner(ChatOpenAI):
"""ChatDeepseekReasoner with custom overrides for R1/reasoner models."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def invocation_params(
self, options: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
params = super().invocation_params(options, **kwargs)
# Remove unsupported params for R1 models
params.pop("temperature", None)
params.pop("top_p", None)
params.pop("presence_penalty", None)
params.pop("frequency_penalty", None)
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override _generate to ensure message alternation in accordance to Deepseek API."""
processed = []
prev_role = None
for msg in messages:
current_role = "user" if msg.type == "human" else "assistant"
if prev_role == current_role:
if processed:
processed[-1].content += f"\n\n{msg.content}"
else:
processed.append(msg)
prev_role = current_role
return super()._generate(
processed, stop=stop, run_manager=run_manager, **kwargs
)
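
For illustration, a standalone sketch of the merge step above applied to a toy message list; no API call is made, and the loop simply mirrors the merging logic of `_generate`:

```python
# Sketch: replays the role-merging loop from ChatDeepseekReasoner._generate,
# since the reasoner endpoint expects user/assistant turns to alternate.
from langchain_core.messages import AIMessage, HumanMessage

messages = [
    HumanMessage(content="Summarize the design."),
    HumanMessage(content="Keep it under five bullets."),  # second user turn in a row
    AIMessage(content="1. ..."),
]

processed, prev_role = [], None
for msg in messages:
    current_role = "user" if msg.type == "human" else "assistant"
    if prev_role == current_role and processed:
        processed[-1].content += f"\n\n{msg.content}"  # fold into the previous turn
    else:
        processed.append(msg)
    prev_role = current_role

print(len(processed))  # 2 -- the consecutive human messages collapse into one turn
```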

View File

@@ -50,6 +50,9 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
'gemini': {
'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY',
'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL'
},
'deepseek': {
'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY'
}
}
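
A small sketch of the intended expert fallback for this mapping. The call signature comes from the hunk header above; the copy-when-unset behavior is inferred from the function name and the validation strategy further down, so treat it as an assumption:

```python
# Sketch (behavior inferred): the deepseek mapping lets the expert tool reuse
# the base key when EXPERT_DEEPSEEK_API_KEY is not set explicitly.
import os

from ra_aid.env import copy_base_to_expert_vars

os.environ["DEEPSEEK_API_KEY"] = "sk-base"
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

copy_base_to_expert_vars(base_provider="deepseek", expert_provider="deepseek")

# Expected under the assumption above: the base key has been copied over.
print(os.environ.get("EXPERT_DEEPSEEK_API_KEY"))
```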

View File

@@ -1,107 +1,184 @@
import os
from typing import Optional, Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
def get_env_var(name: str, expert: bool = False) -> Optional[str]:
"""Get environment variable with optional expert prefix and fallback."""
prefix = "EXPERT_" if expert else ""
value = os.getenv(f"{prefix}{name}")
def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel:
"""Initialize a language model client based on the specified provider and model.
# If expert mode and no expert value, fall back to base value
if expert and not value:
value = os.getenv(name)
Note: Environment variables must be validated before calling this function.
Use validate_environment() to ensure all required variables are set.
return value
def create_deepseek_client(
model_name: str,
api_key: str,
base_url: str,
temperature: Optional[float] = None,
is_expert: bool = False,
) -> BaseChatModel:
"""Create DeepSeek client with appropriate configuration."""
if model_name.lower() == "deepseek-reasoner":
return ChatDeepseekReasoner(
api_key=api_key,
base_url=base_url,
temperature=0
if is_expert
else (temperature if temperature is not None else 1),
model=model_name,
)
return ChatOpenAI(
api_key=api_key,
base_url=base_url,
temperature=0 if is_expert else (temperature if temperature is not None else 1),
model=model_name,
)
def create_openrouter_client(
model_name: str,
api_key: str,
temperature: Optional[float] = None,
is_expert: bool = False,
) -> BaseChatModel:
"""Create OpenRouter client with appropriate configuration."""
if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
return ChatDeepseekReasoner(
api_key=api_key,
base_url="https://openrouter.ai/api/v1",
temperature=0
if is_expert
else (temperature if temperature is not None else 1),
model=model_name,
)
return ChatOpenAI(
api_key=api_key,
base_url="https://openrouter.ai/api/v1",
model=model_name,
**({"temperature": temperature} if temperature is not None else {}),
)
def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]:
"""Get provider-specific configuration."""
configs = {
"openai": {
"api_key": get_env_var("OPENAI_API_KEY", is_expert),
"base_url": None,
},
"anthropic": {
"api_key": get_env_var("ANTHROPIC_API_KEY", is_expert),
"base_url": None,
},
"openrouter": {
"api_key": get_env_var("OPENROUTER_API_KEY", is_expert),
"base_url": "https://openrouter.ai/api/v1",
},
"openai-compatible": {
"api_key": get_env_var("OPENAI_API_KEY", is_expert),
"base_url": get_env_var("OPENAI_API_BASE", is_expert),
},
"gemini": {
"api_key": get_env_var("GEMINI_API_KEY", is_expert),
"base_url": None,
},
"deepseek": {
"api_key": get_env_var("DEEPSEEK_API_KEY", is_expert),
"base_url": "https://api.deepseek.com",
},
}
return configs.get(provider, {})
def create_llm_client(
provider: str,
model_name: str,
temperature: Optional[float] = None,
is_expert: bool = False,
) -> BaseChatModel:
"""Create a language model client with appropriate configuration.
Args:
provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini')
provider: The LLM provider to use
model_name: Name of the model to use
temperature: Optional temperature setting for controlling randomness (0.0-2.0).
If not specified, provider-specific defaults are used.
temperature: Optional temperature setting (0.0-2.0)
is_expert: Whether this is an expert model (uses deterministic output)
Returns:
BaseChatModel: Configured language model client
Raises:
ValueError: If the provider is not supported
Configured language model client
"""
if provider == "openai":
config = get_provider_config(provider, is_expert)
if not config:
raise ValueError(f"Unsupported provider: {provider}")
# Handle temperature for expert mode
if is_expert:
temperature = 0
if provider == "deepseek":
return create_deepseek_client(
model_name=model_name,
api_key=config["api_key"],
base_url=config["base_url"],
temperature=temperature,
is_expert=is_expert,
)
elif provider == "openrouter":
return create_openrouter_client(
model_name=model_name,
api_key=config["api_key"],
temperature=temperature,
is_expert=is_expert,
)
elif provider == "openai":
return ChatOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
api_key=config["api_key"],
model=model_name,
**({"temperature": temperature} if temperature is not None else {})
**({"temperature": temperature} if temperature is not None else {}),
)
elif provider == "anthropic":
return ChatAnthropic(
api_key=os.getenv("ANTHROPIC_API_KEY"),
api_key=config["api_key"],
model_name=model_name,
**({"temperature": temperature} if temperature is not None else {})
)
elif provider == "openrouter":
return ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model=model_name,
**({"temperature": temperature} if temperature is not None else {})
**({"temperature": temperature} if temperature is not None else {}),
)
elif provider == "openai-compatible":
return ChatOpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_API_BASE"),
api_key=config["api_key"],
base_url=config["base_url"],
temperature=temperature if temperature is not None else 0.3,
model=model_name,
)
elif provider == "gemini":
return ChatGoogleGenerativeAI(
api_key=os.getenv("GEMINI_API_KEY"),
api_key=config["api_key"],
model=model_name,
**({"temperature": temperature} if temperature is not None else {})
**({"temperature": temperature} if temperature is not None else {}),
)
else:
raise ValueError(f"Unsupported provider: {provider}")
def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> BaseChatModel:
"""Initialize an expert language model client based on the specified provider and model.
Note: Environment variables must be validated before calling this function.
Use validate_environment() to ensure all required variables are set.
def initialize_llm(
provider: str, model_name: str, temperature: float | None = None
) -> BaseChatModel:
"""Initialize a language model client based on the specified provider and model."""
return create_llm_client(provider, model_name, temperature, is_expert=False)
Args:
provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini').
Defaults to 'openai'.
model_name: Name of the model to use. Defaults to 'o1'.
Returns:
BaseChatModel: Configured expert language model client
Raises:
ValueError: If the provider is not supported
"""
if provider == "openai":
return ChatOpenAI(
api_key=os.getenv("EXPERT_OPENAI_API_KEY"),
model=model_name,
)
elif provider == "anthropic":
return ChatAnthropic(
api_key=os.getenv("EXPERT_ANTHROPIC_API_KEY"),
model_name=model_name,
)
elif provider == "openrouter":
return ChatOpenAI(
api_key=os.getenv("EXPERT_OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model=model_name,
)
elif provider == "openai-compatible":
return ChatOpenAI(
api_key=os.getenv("EXPERT_OPENAI_API_KEY"),
base_url=os.getenv("EXPERT_OPENAI_API_BASE"),
model=model_name,
)
elif provider == "gemini":
return ChatGoogleGenerativeAI(
api_key=os.getenv("EXPERT_GEMINI_API_KEY"),
model=model_name,
)
else:
raise ValueError(f"Unsupported provider: {provider}")
def initialize_expert_llm(
provider: str = "openai", model_name: str = "o1"
) -> BaseChatModel:
"""Initialize an expert language model client based on the specified provider and model."""
return create_llm_client(provider, model_name, temperature=None, is_expert=True)
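
A short sketch of how the refactored helpers compose, matching the fallback behavior of `get_env_var` and the expert handling in `create_llm_client` shown above:

```python
# Sketch: the helper functions introduced by this refactor, used directly.
import os

from ra_aid.llm import create_llm_client, get_env_var, get_provider_config

os.environ["DEEPSEEK_API_KEY"] = "sk-base"
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

# Expert lookups fall back to the base variable when no EXPERT_ value is set.
assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "sk-base"

config = get_provider_config("deepseek", is_expert=True)
assert config["base_url"] == "https://api.deepseek.com"

# Expert clients are forced to temperature=0 inside create_llm_client.
expert_llm = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True)
```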

View File

@@ -241,6 +241,10 @@ models_tokens = {
"deepseek": {
"deepseek-chat": 28672,
"deepseek-coder": 16384,
"deepseek-reasoner": 65536,
},
"openrouter": {
"deepseek/deepseek-r1": 65536,
},
"ernie": {
"ernie-bot-turbo": 4096,

View File

@@ -239,6 +239,31 @@ class GeminiStrategy(ProviderStrategy):
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class DeepSeekStrategy(ProviderStrategy):
"""DeepSeek provider validation strategy."""
def validate(self, args: Optional[Any] = None) -> ValidationResult:
"""Validate DeepSeek environment variables."""
missing = []
if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek':
key = os.environ.get('EXPERT_DEEPSEEK_API_KEY')
if not key or key == '':
# Try to copy from base if not set
base_key = os.environ.get('DEEPSEEK_API_KEY')
if base_key:
os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key
key = base_key
if not key:
missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set')
else:
key = os.environ.get('DEEPSEEK_API_KEY')
if not key:
missing.append('DEEPSEEK_API_KEY environment variable is not set')
return ValidationResult(valid=len(missing) == 0, missing_vars=missing)
class OllamaStrategy(ProviderStrategy):
"""Ollama provider validation strategy."""
@@ -272,7 +297,8 @@ class ProviderFactory:
'anthropic': AnthropicStrategy(),
'openrouter': OpenRouterStrategy(),
'gemini': GeminiStrategy(),
'ollama': OllamaStrategy()
'ollama': OllamaStrategy(),
'deepseek': DeepSeekStrategy()
}
strategy = strategies.get(provider)
return strategy
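
The new strategy can also be exercised on its own; a standalone sketch based on the `validate` signature above (the `ra_aid.provider_strategy` import path is assumed from the file layout):

```python
# Sketch: direct use of DeepSeekStrategy, normally reached via ProviderFactory
# during validate_environment().
import os

from ra_aid.provider_strategy import DeepSeekStrategy

os.environ.pop("DEEPSEEK_API_KEY", None)
result = DeepSeekStrategy().validate(args=None)
print(result.valid)          # False
print(result.missing_vars)   # ['DEEPSEEK_API_KEY environment variable is not set']

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"
print(DeepSeekStrategy().validate(args=None).valid)  # True
```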

View File

@@ -1,123 +1,147 @@
import os
import os
from typing import Dict, List, Any, Union, Optional, Set
from typing_extensions import TypedDict
class WorkLogEntry(TypedDict):
timestamp: str
event: str
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from langchain_core.tools import tool
class WorkLogEntry(TypedDict):
timestamp: str
event: str
class SnippetInfo(TypedDict):
"""Type definition for source code snippet information"""
filepath: str
line_number: int
snippet: str
description: Optional[str]
console = Console()
# Global memory store
_global_memory: Dict[str, Union[List[Any], Dict[int, str], Dict[int, SnippetInfo], int, Set[str], bool, str, int, List[WorkLogEntry]]] = {
'research_notes': [],
'plans': [],
'tasks': {}, # Dict[int, str] - ID to task mapping
'task_completed': False, # Flag indicating if task is complete
'completion_message': '', # Message explaining completion
'task_id_counter': 1, # Counter for generating unique task IDs
'key_facts': {}, # Dict[int, str] - ID to fact mapping
'key_fact_id_counter': 1, # Counter for generating unique fact IDs
'key_snippets': {}, # Dict[int, SnippetInfo] - ID to snippet mapping
'key_snippet_id_counter': 1, # Counter for generating unique snippet IDs
'implementation_requested': False,
'related_files': {}, # Dict[int, str] - ID to filepath mapping
'related_file_id_counter': 1, # Counter for generating unique file IDs
'plan_completed': False,
'agent_depth': 0,
'work_log': [] # List[WorkLogEntry] - Timestamped work events
_global_memory: Dict[
str,
Union[
List[Any],
Dict[int, str],
Dict[int, SnippetInfo],
int,
Set[str],
bool,
str,
int,
List[WorkLogEntry],
],
] = {
"research_notes": [],
"plans": [],
"tasks": {}, # Dict[int, str] - ID to task mapping
"task_completed": False, # Flag indicating if task is complete
"completion_message": "", # Message explaining completion
"task_id_counter": 1, # Counter for generating unique task IDs
"key_facts": {}, # Dict[int, str] - ID to fact mapping
"key_fact_id_counter": 1, # Counter for generating unique fact IDs
"key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping
"key_snippet_id_counter": 1, # Counter for generating unique snippet IDs
"implementation_requested": False,
"related_files": {}, # Dict[int, str] - ID to filepath mapping
"related_file_id_counter": 1, # Counter for generating unique file IDs
"plan_completed": False,
"agent_depth": 0,
"work_log": [], # List[WorkLogEntry] - Timestamped work events
}
@tool("emit_research_notes")
def emit_research_notes(notes: str) -> str:
"""Store research notes in global memory.
Args:
notes: REQUIRED The research notes to store
Returns:
The stored notes
"""
_global_memory['research_notes'].append(notes)
_global_memory["research_notes"].append(notes)
console.print(Panel(Markdown(notes), title="🔍 Research Notes"))
return notes
@tool("emit_plan")
def emit_plan(plan: str) -> str:
"""Store a plan step in global memory.
Args:
plan: The plan step to store (markdown format; be clear, complete, use newlines, and use as many tokens as you need)
Returns:
The stored plan
"""
_global_memory['plans'].append(plan)
_global_memory["plans"].append(plan)
console.print(Panel(Markdown(plan), title="📋 Plan"))
log_work_event(f"Added plan step:\n\n{plan}")
return plan
@tool("emit_task")
def emit_task(task: str) -> str:
"""Store a task in global memory.
Args:
task: The task to store
Returns:
String confirming task storage with ID number
"""
# Get and increment task ID
task_id = _global_memory['task_id_counter']
_global_memory['task_id_counter'] += 1
task_id = _global_memory["task_id_counter"]
_global_memory["task_id_counter"] += 1
# Store task with ID
_global_memory['tasks'][task_id] = task
_global_memory["tasks"][task_id] = task
console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}"))
log_work_event(f"Task #{task_id} added:\n\n{task}")
return f"Task #{task_id} stored."
@tool("emit_key_facts")
def emit_key_facts(facts: List[str]) -> str:
"""Store multiple key facts about the project or current task in global memory.
Args:
facts: List of key facts to store
Returns:
List of stored fact confirmation messages
"""
results = []
for fact in facts:
# Get and increment fact ID
fact_id = _global_memory['key_fact_id_counter']
_global_memory['key_fact_id_counter'] += 1
fact_id = _global_memory["key_fact_id_counter"]
_global_memory["key_fact_id_counter"] += 1
# Store fact with ID
_global_memory['key_facts'][fact_id] = fact
_global_memory["key_facts"][fact_id] = fact
# Display panel with ID
console.print(Panel(Markdown(fact), title=f"💡 Key Fact #{fact_id}", border_style="bright_cyan"))
console.print(
Panel(
Markdown(fact),
title=f"💡 Key Fact #{fact_id}",
border_style="bright_cyan",
)
)
# Add result message
results.append(f"Stored fact #{fact_id}: {fact}")
log_work_event(f"Stored {len(facts)} key facts.")
log_work_event(f"Stored {len(facts)} key facts.")
return "Facts stored."
@@ -125,50 +149,54 @@ def emit_key_facts(facts: List[str]) -> str:
def delete_key_facts(fact_ids: List[int]) -> str:
"""Delete multiple key facts from global memory by their IDs.
Silently skips any IDs that don't exist.
Args:
fact_ids: List of fact IDs to delete
Returns:
List of success messages for deleted facts
"""
results = []
for fact_id in fact_ids:
if fact_id in _global_memory['key_facts']:
if fact_id in _global_memory["key_facts"]:
# Delete the fact
deleted_fact = _global_memory['key_facts'].pop(fact_id)
deleted_fact = _global_memory["key_facts"].pop(fact_id)
success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
console.print(Panel(Markdown(success_msg), title="Fact Deleted", border_style="green"))
console.print(
Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
)
results.append(success_msg)
log_work_event(f"Deleted facts {fact_ids}.")
log_work_event(f"Deleted facts {fact_ids}.")
return "Facts deleted."
@tool("delete_tasks")
def delete_tasks(task_ids: List[int]) -> str:
"""Delete multiple tasks from global memory by their IDs.
Silently skips any IDs that don't exist.
Args:
task_ids: List of task IDs to delete
Returns:
Confirmation message
"""
results = []
for task_id in task_ids:
if task_id in _global_memory['tasks']:
if task_id in _global_memory["tasks"]:
# Delete the task
deleted_task = _global_memory['tasks'].pop(task_id)
deleted_task = _global_memory["tasks"].pop(task_id)
success_msg = f"Successfully deleted task #{task_id}: {deleted_task}"
console.print(Panel(Markdown(success_msg),
title="Task Deleted",
border_style="green"))
console.print(
Panel(Markdown(success_msg), title="Task Deleted", border_style="green")
)
results.append(success_msg)
log_work_event(f"Deleted tasks {task_ids}.")
log_work_event(f"Deleted tasks {task_ids}.")
return "Tasks deleted."
@tool("request_implementation")
def request_implementation() -> str:
"""Request that implementation proceed after research/planning.
@@ -178,46 +206,47 @@ def request_implementation() -> str:
Do you need to request research subtasks first?
Have you run relevant unit tests, if they exist, to get a baseline (this can be a subtask)?
Do you need to crawl deeper to find all related files and symbols?
Returns:
Empty string
"""
_global_memory['implementation_requested'] = True
_global_memory["implementation_requested"] = True
console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0))
log_work_event("Implementation requested.")
return ""
@tool("emit_key_snippets")
def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
"""Store multiple key source code snippets in global memory.
Automatically adds the filepaths of the snippets to related files.
This is for **existing**, or **just-written** files, not for things to be created in the future.
Args:
snippets: REQUIRED List of snippet information dictionaries containing:
- filepath: Path to the source file
- line_number: Line number where the snippet starts
- line_number: Line number where the snippet starts
- snippet: The source code snippet text
- description: Optional description of the significance
Returns:
List of stored snippet confirmation messages
"""
# First collect unique filepaths to add as related files
emit_related_files.invoke({"files": [snippet_info['filepath'] for snippet_info in snippets]})
emit_related_files.invoke(
{"files": [snippet_info["filepath"] for snippet_info in snippets]}
)
results = []
for snippet_info in snippets:
# Get and increment snippet ID
snippet_id = _global_memory['key_snippet_id_counter']
_global_memory['key_snippet_id_counter'] += 1
# Get and increment snippet ID
snippet_id = _global_memory["key_snippet_id_counter"]
_global_memory["key_snippet_id_counter"] += 1
# Store snippet info
_global_memory['key_snippets'][snippet_id] = snippet_info
_global_memory["key_snippets"][snippet_id] = snippet_info
# Format display text as markdown
display_text = [
f"**Source Location**:",
@@ -226,153 +255,170 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
"", # Empty line before code block
"**Code**:",
"```python",
snippet_info['snippet'].rstrip(), # Remove trailing whitespace
"```"
snippet_info["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if snippet_info['description']:
display_text.extend(["", "**Description**:", snippet_info['description']])
if snippet_info["description"]:
display_text.extend(["", "**Description**:", snippet_info["description"]])
# Display panel
console.print(Panel(Markdown("\n".join(display_text)),
title=f"📝 Key Snippet #{snippet_id}",
border_style="bright_cyan"))
console.print(
Panel(
Markdown("\n".join(display_text)),
title=f"📝 Key Snippet #{snippet_id}",
border_style="bright_cyan",
)
)
results.append(f"Stored snippet #{snippet_id}")
log_work_event(f"Stored {len(snippets)} code snippets.")
log_work_event(f"Stored {len(snippets)} code snippets.")
return "Snippets stored."
@tool("delete_key_snippets")
@tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str:
"""Delete multiple key snippets from global memory by their IDs.
Silently skips any IDs that don't exist.
Args:
snippet_ids: List of snippet IDs to delete
Returns:
List of success messages for deleted snippets
"""
results = []
for snippet_id in snippet_ids:
if snippet_id in _global_memory['key_snippets']:
if snippet_id in _global_memory["key_snippets"]:
# Delete the snippet
deleted_snippet = _global_memory['key_snippets'].pop(snippet_id)
deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
console.print(Panel(Markdown(success_msg),
title="Snippet Deleted",
border_style="green"))
console.print(
Panel(
Markdown(success_msg), title="Snippet Deleted", border_style="green"
)
)
results.append(success_msg)
log_work_event(f"Deleted snippets {snippet_ids}.")
log_work_event(f"Deleted snippets {snippet_ids}.")
return "Snippets deleted."
@tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str:
"""Swap the order of two tasks in global memory by their IDs.
Args:
id1: First task ID
id2: Second task ID
Returns:
Success or error message depending on outcome
"""
# Validate IDs are different
if id1 == id2:
return "Cannot swap task with itself"
# Validate both IDs exist
if id1 not in _global_memory['tasks'] or id2 not in _global_memory['tasks']:
if id1 not in _global_memory["tasks"] or id2 not in _global_memory["tasks"]:
return "Invalid task ID(s)"
# Swap the tasks
_global_memory['tasks'][id1], _global_memory['tasks'][id2] = \
_global_memory['tasks'][id2], _global_memory['tasks'][id1]
_global_memory["tasks"][id1], _global_memory["tasks"][id2] = (
_global_memory["tasks"][id2],
_global_memory["tasks"][id1],
)
# Display what was swapped
console.print(Panel(
Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"),
title="🔄 Tasks Reordered",
border_style="green"
))
console.print(
Panel(
Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"),
title="🔄 Tasks Reordered",
border_style="green",
)
)
return "Tasks swapped."
@tool("one_shot_completed")
def one_shot_completed(message: str) -> str:
"""Signal that a one-shot task has been completed and execution should stop.
Only call this if you have already **fully** completed the original request.
Args:
message: Completion message to display
"""
if _global_memory.get('implementation_requested', False):
if _global_memory.get("implementation_requested", False):
return "Cannot complete in one shot - implementation was requested"
_global_memory['task_completed'] = True
_global_memory['completion_message'] = message
_global_memory["task_completed"] = True
_global_memory["completion_message"] = message
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed\n\n{message}")
return "Completion noted."
@tool("task_completed")
def task_completed(message: str) -> str:
"""Mark the current task as completed with a completion message.
Args:
message: Message explaining how/why the task is complete
Returns:
The completion message
"""
_global_memory['task_completed'] = True
_global_memory['completion_message'] = message
_global_memory["task_completed"] = True
_global_memory["completion_message"] = message
console.print(Panel(Markdown(message), title="✅ Task Completed"))
return "Completion noted."
@tool("plan_implementation_completed")
def plan_implementation_completed(message: str) -> str:
"""Mark the entire implementation plan as completed.
Args:
message: Message explaining how the implementation plan was completed
Returns:
Confirmation message
"""
_global_memory['plan_completed'] = True
_global_memory['completion_message'] = message
_global_memory['tasks'].clear() # Clear task list when plan is completed
_global_memory['task_id_counter'] = 1
_global_memory["plan_completed"] = True
_global_memory["completion_message"] = message
_global_memory["tasks"].clear() # Clear task list when plan is completed
_global_memory["task_id_counter"] = 1
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Plan execution completed:\n\n{message}")
return "Plan completion noted and task list cleared."
def get_related_files() -> List[str]:
"""Get the current list of related files.
Returns:
List of formatted strings in the format 'ID#X path/to/file.py'
"""
files = _global_memory['related_files']
files = _global_memory["related_files"]
return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]
@tool("emit_related_files")
def emit_related_files(files: List[str]) -> str:
"""Store multiple related files that tools should work with.
Args:
files: List of file paths to add
Returns:
Formatted string containing file IDs and paths for all processed files
"""
results = []
added_files = []
invalid_paths = []
# Process files
for file in files:
# First check if path exists
@@ -380,111 +426,115 @@ def emit_related_files(files: List[str]) -> str:
invalid_paths.append(file)
results.append(f"Error: Path '{file}' does not exist")
continue
# Then check if it's a directory
if os.path.isdir(file):
invalid_paths.append(file)
results.append(f"Error: Path '{file}' is a directory, not a file")
continue
# Finally validate it's a regular file
if not os.path.isfile(file):
invalid_paths.append(file)
results.append(f"Error: Path '{file}' exists but is not a regular file")
continue
# Check if file path already exists in values
existing_id = None
for fid, fpath in _global_memory['related_files'].items():
for fid, fpath in _global_memory["related_files"].items():
if fpath == file:
existing_id = fid
break
if existing_id is not None:
# File exists, use existing ID
results.append(f"File ID #{existing_id}: {file}")
else:
# New file, assign new ID
file_id = _global_memory['related_file_id_counter']
_global_memory['related_file_id_counter'] += 1
file_id = _global_memory["related_file_id_counter"]
_global_memory["related_file_id_counter"] += 1
# Store file with ID
_global_memory['related_files'][file_id] = file
_global_memory["related_files"][file_id] = file
added_files.append((file_id, file))
results.append(f"File ID #{file_id}: {file}")
# Rich output - single consolidated panel
if added_files:
files_added_md = '\n'.join(f"- `{file}`" for id, file in added_files)
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
md_content = f"**Files Noted:**\n{files_added_md}"
console.print(Panel(Markdown(md_content),
title="📁 Related Files Noted",
border_style="green"))
return '\n'.join(results)
console.print(
Panel(
Markdown(md_content),
title="📁 Related Files Noted",
border_style="green",
)
)
return "\n".join(results)
def log_work_event(event: str) -> str:
"""Add timestamped entry to work log.
Internal function used to track major events during agent execution.
Each entry is stored with an ISO format timestamp.
Args:
event: Description of the event to log
Returns:
Confirmation message
Note:
Entries can be retrieved with get_work_log() as markdown formatted text.
"""
from datetime import datetime
entry = WorkLogEntry(
timestamp=datetime.now().isoformat(),
event=event
)
_global_memory['work_log'].append(entry)
entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
_global_memory["work_log"].append(entry)
return f"Event logged: {event}"
def get_work_log() -> str:
"""Return formatted markdown of work log entries.
Returns:
Markdown formatted text with timestamps as headings and events as content,
or 'No work log entries' if the log is empty.
Example:
## 2024-12-23T11:39:10
Task #1 added: Create login form
"""
if not _global_memory['work_log']:
if not _global_memory["work_log"]:
return "No work log entries"
entries = []
for entry in _global_memory['work_log']:
entries.extend([
f"## {entry['timestamp']}",
"",
entry['event'],
"" # Blank line between entries
])
for entry in _global_memory["work_log"]:
entries.extend(
[
f"## {entry['timestamp']}",
"",
entry["event"],
"", # Blank line between entries
]
)
return "\n".join(entries).rstrip() # Remove trailing newline
def reset_work_log() -> str:
"""Clear the work log.
Returns:
Returns:
Confirmation message
Note:
This permanently removes all work log entries. The operation cannot be undone.
"""
_global_memory['work_log'].clear()
_global_memory["work_log"].clear()
return "Work log cleared"
@@ -492,37 +542,44 @@ def reset_work_log() -> str:
def deregister_related_files(file_ids: List[int]) -> str:
"""Delete multiple related files from global memory by their IDs.
Silently skips any IDs that don't exist.
Args:
file_ids: List of file IDs to delete
Returns:
Success message string
"""
results = []
for file_id in file_ids:
if file_id in _global_memory['related_files']:
if file_id in _global_memory["related_files"]:
# Delete the file reference
deleted_file = _global_memory['related_files'].pop(file_id)
success_msg = f"Successfully removed related file #{file_id}: {deleted_file}"
console.print(Panel(Markdown(success_msg),
title="File Reference Removed",
border_style="green"))
deleted_file = _global_memory["related_files"].pop(file_id)
success_msg = (
f"Successfully removed related file #{file_id}: {deleted_file}"
)
console.print(
Panel(
Markdown(success_msg),
title="File Reference Removed",
border_style="green",
)
)
results.append(success_msg)
return "File references removed."
def get_memory_value(key: str) -> str:
"""Get a value from global memory.
Different memory types return different formats:
- key_facts: Returns numbered list of facts in format '#ID: fact'
- key_snippets: Returns formatted snippets with file path, line number and content
- All other types: Returns newline-separated list of values
Args:
key: The key to get from memory
Returns:
String representation of the memory values:
- For key_facts: '#ID: fact' format, one per line
@@ -530,23 +587,25 @@ def get_memory_value(key: str) -> str:
- For other types: One value per line
"""
values = _global_memory.get(key, [])
if key == 'key_facts':
if key == "key_facts":
# For empty dict, return empty string
if not values:
return ""
# Sort by ID for consistent output and format as markdown sections
facts = []
for k, v in sorted(values.items()):
facts.extend([
f"## 🔑 Key Fact #{k}",
"", # Empty line for better markdown spacing
v,
"" # Empty line between facts
])
facts.extend(
[
f"## 🔑 Key Fact #{k}",
"", # Empty line for better markdown spacing
v,
"", # Empty line between facts
]
)
return "\n".join(facts).rstrip() # Remove trailing newline
if key == 'key_snippets':
if key == "key_snippets":
if not values:
return ""
# Format each snippet with file info and content using markdown
@@ -555,26 +614,25 @@ def get_memory_value(key: str) -> str:
snippet_text = [
f"## 📝 Code Snippet #{k}",
"", # Empty line for better markdown spacing
f"**Source Location**:",
"**Source Location**:",
f"- File: `{v['filepath']}`",
f"- Line: `{v['line_number']}`",
"", # Empty line before code block
"**Code**:",
"```python",
v['snippet'].rstrip(), # Remove trailing whitespace
"```"
v["snippet"].rstrip(), # Remove trailing whitespace
"```",
]
if v['description']:
if v["description"]:
# Add empty line and description
snippet_text.extend(["", "**Description**:", v['description']])
snippet_text.extend(["", "**Description**:", v["description"]])
snippets.append("\n".join(snippet_text))
return "\n\n".join(snippets)
if key == 'work_log':
if key == "work_log":
if not values:
return ""
entries = [f"## {entry['timestamp']}\n{entry['event']}"
for entry in values]
entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
return "\n\n".join(entries)
# For other types (lists), join with newlines
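
Most of this file is mechanical reformatting (quote style, Black-style wrapping); the tool behavior is unchanged. For orientation, a minimal sketch of the storage round trip these helpers implement, assuming the tools live under `ra_aid.tools.memory`:

```python
# Sketch: one round trip through the global memory helpers (module path assumed).
from ra_aid.tools.memory import emit_key_facts, get_memory_value

# Tools are LangChain @tool objects, so they are called via .invoke().
emit_key_facts.invoke({"facts": ["DeepSeek R1 rejects temperature/top_p parameters."]})

print(get_memory_value("key_facts"))
# ## 🔑 Key Fact #1
#
# DeepSeek R1 rejects temperature/top_p parameters.
```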

View File

@@ -9,7 +9,13 @@ from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.env import validate_environment
from ra_aid.llm import initialize_llm, initialize_expert_llm
from ra_aid.llm import (
initialize_llm,
initialize_expert_llm,
get_env_var,
get_provider_config,
create_llm_client
)
@pytest.fixture
def clean_env(monkeypatch):
@@ -39,7 +45,8 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
mock_openai.assert_called_once_with(
api_key="test-key",
model="o1"
model="o1",
temperature=0
)
def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
@@ -49,7 +56,8 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
mock_openai.assert_called_once_with(
api_key="test-key",
model="gpt-4-preview"
model="gpt-4-preview",
temperature=0
)
def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
@@ -59,7 +67,8 @@ def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
mock_gemini.assert_called_once_with(
api_key="test-key",
model="gemini-2.0-flash-thinking-exp-1219"
model="gemini-2.0-flash-thinking-exp-1219",
temperature=0
)
def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
@@ -69,7 +78,8 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
mock_anthropic.assert_called_once_with(
api_key="test-key",
model_name="claude-3"
model_name="claude-3",
temperature=0
)
def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
@@ -80,7 +90,8 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
mock_openai.assert_called_once_with(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="models/mistral-large"
model="models/mistral-large",
temperature=0
)
def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch):
@@ -92,7 +103,8 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
mock_openai.assert_called_once_with(
api_key="test-key",
base_url="http://test-url",
model="local-model"
model="local-model",
temperature=0
)
def test_initialize_expert_unsupported_provider(clean_env):
@@ -311,45 +323,49 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m
model="gemini-2.0-flash-thinking-exp-1219"
)
@dataclass
class Args:
"""Test arguments class."""
provider: str
expert_provider: str
model: str = None
expert_model: str = None
def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
"""Test environment variable precedence and fallback."""
from ra_aid.env import validate_environment
from dataclasses import dataclass
# Test get_env_var helper with fallback
monkeypatch.setenv("TEST_KEY", "base-value")
monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value")
@dataclass
class Args:
provider: str
expert_provider: str
model: str = None
expert_model: str = None
# Test expert mode with explicit key
# Set up base environment first
monkeypatch.setenv("OPENAI_API_KEY", "base-key")
assert get_env_var("TEST_KEY") == "base-value"
assert get_env_var("TEST_KEY", expert=True) == "expert-value"
# Test fallback when expert value not set
monkeypatch.delenv("EXPERT_TEST_KEY", raising=False)
assert get_env_var("TEST_KEY", expert=True) == "base-value"
# Test provider config
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key")
monkeypatch.setenv("TAVILY_API_KEY", "tavily-key")
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
args = Args(provider="openai", expert_provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert expert_enabled
assert not expert_missing
assert web_enabled
assert not web_missing
llm = initialize_expert_llm()
config = get_provider_config("openai", is_expert=True)
assert config["api_key"] == "expert-key"
# Test LLM client creation with expert mode
llm = create_llm_client("openai", "o1", is_expert=True)
mock_openai.assert_called_with(
api_key="expert-key",
model="o1"
model="o1",
temperature=0
)
# Test empty key validation
# Test environment validation
monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback
monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") # Add for provider validation
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation
args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("TAVILY_API_KEY", raising=False)
monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307")
args = Args(provider="anthropic", expert_provider="openai")
expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
assert not expert_enabled
assert expert_missing
@@ -372,3 +388,122 @@ def mock_gemini():
with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock:
mock.return_value = Mock(spec=ChatGoogleGenerativeAI)
yield mock
@pytest.fixture
def mock_deepseek_reasoner():
"""Mock ChatDeepseekReasoner for testing DeepSeek provider initialization."""
with patch('ra_aid.llm.ChatDeepseekReasoner') as mock:
mock.return_value = Mock()
yield mock
def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
"""Test DeepSeek provider initialization with different models."""
monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key")
# Test with reasoner model
model = initialize_llm("deepseek", "deepseek-reasoner")
mock_deepseek_reasoner.assert_called_with(
api_key="test-key",
base_url="https://api.deepseek.com",
temperature=1,
model="deepseek-reasoner"
)
# Test with non-reasoner model
model = initialize_llm("deepseek", "deepseek-chat")
mock_openai.assert_called_with(
api_key="test-key",
base_url="https://api.deepseek.com",
temperature=1,
model="deepseek-chat"
)
def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
"""Test expert DeepSeek provider initialization."""
monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key")
# Test with reasoner model
model = initialize_expert_llm("deepseek", "deepseek-reasoner")
mock_deepseek_reasoner.assert_called_with(
api_key="test-key",
base_url="https://api.deepseek.com",
temperature=0,
model="deepseek-reasoner"
)
# Test with non-reasoner model
model = initialize_expert_llm("deepseek", "deepseek-chat")
mock_openai.assert_called_with(
api_key="test-key",
base_url="https://api.deepseek.com",
temperature=0,
model="deepseek-chat"
)
def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
"""Test OpenRouter DeepSeek model initialization."""
monkeypatch.setenv("OPENROUTER_API_KEY", "test-key")
# Test with DeepSeek R1 model
model = initialize_llm("openrouter", "deepseek/deepseek-r1")
mock_deepseek_reasoner.assert_called_with(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
temperature=1,
model="deepseek/deepseek-r1"
)
# Test with non-DeepSeek model
model = initialize_llm("openrouter", "mistral/mistral-large")
mock_openai.assert_called_with(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="mistral/mistral-large"
)
def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
"""Test expert OpenRouter DeepSeek model initialization."""
monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key")
# Test with DeepSeek R1 model via create_llm_client
model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True)
mock_deepseek_reasoner.assert_called_with(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
temperature=0,
model="deepseek/deepseek-r1"
)
# Test with non-DeepSeek model
model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True)
mock_openai.assert_called_with(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
model="mistral/mistral-large",
temperature=0
)
def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch):
"""Test DeepSeek environment variable fallback behavior."""
# Test environment variable helper with fallback
monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key")
assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key"
# Test provider config with fallback
config = get_provider_config("deepseek", is_expert=True)
assert config["api_key"] == "base-key"
assert config["base_url"] == "https://api.deepseek.com"
# Test with expert key
monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key")
config = get_provider_config("deepseek", is_expert=True)
assert config["api_key"] == "expert-key"
# Test client creation with expert key
model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True)
mock_deepseek_reasoner.assert_called_with(
api_key="expert-key",
base_url="https://api.deepseek.com",
temperature=0,
model="deepseek-reasoner"
)