diff --git a/README.md b/README.md index 6aa9851..d298451 100644 --- a/README.md +++ b/README.md @@ -293,6 +293,7 @@ RA.Aid supports multiple providers through environment variables: - `ANTHROPIC_API_KEY`: Required for the default Anthropic provider - `OPENAI_API_KEY`: Required for OpenAI provider - `OPENROUTER_API_KEY`: Required for OpenRouter provider +- `DEEPSEEK_API_KEY`: Required for DeepSeek provider - `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY` - `GEMINI_API_KEY`: Required for Gemini provider @@ -302,6 +303,7 @@ Expert Tool Environment Variables: - `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider - `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider - `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider +- `EXPERT_DEEPSEEK_API_KEY`: API key for expert tool using DeepSeek provider You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`): @@ -343,6 +345,15 @@ export GEMINI_API_KEY=your_api_key_here ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411 ``` +4. **Using DeepSeek** + ```bash + # Direct DeepSeek provider (requires DEEPSEEK_API_KEY) + ra-aid -m "Your task" --provider deepseek --model deepseek-reasoner + + # DeepSeek via OpenRouter + ra-aid -m "Your task" --provider openrouter --model deepseek/deepseek-r1 + ``` + 4. **Configuring Expert Provider** The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables. @@ -356,6 +367,10 @@ export GEMINI_API_KEY=your_api_key_here export OPENROUTER_API_KEY=your_openrouter_api_key ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411 + # Use DeepSeek for expert tool + export DEEPSEEK_API_KEY=your_deepseek_api_key + ra-aid -m "Your task" --expert-provider deepseek --expert-model deepseek-reasoner + # Use default OpenAI for expert tool export EXPERT_OPENAI_API_KEY=your_openai_api_key ra-aid -m "Your task" --expert-provider openai --expert-model o1 diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 6c2489f..4bf5992 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -39,6 +39,7 @@ def parse_arguments(args=None): "openai", "openrouter", "openai-compatible", + "deepseek", "gemini", ] ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022" diff --git a/ra_aid/chat_models/deepseek_chat.py b/ra_aid/chat_models/deepseek_chat.py new file mode 100644 index 0000000..aa22be3 --- /dev/null +++ b/ra_aid/chat_models/deepseek_chat.py @@ -0,0 +1,52 @@ +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult +from langchain_openai import ChatOpenAI +from typing import Any, List, Optional, Dict + + +# Docs: https://api-docs.deepseek.com/guides/reasoning_model +class ChatDeepseekReasoner(ChatOpenAI): + """ChatDeepseekReasoner with custom overrides for R1/reasoner models.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def invocation_params( + self, options: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Dict[str, Any]: + params = super().invocation_params(options, **kwargs) + + # Remove unsupported params for R1 models + params.pop("temperature", None) + 
params.pop("top_p", None) + params.pop("presence_penalty", None) + params.pop("frequency_penalty", None) + + return params + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Override _generate to ensure message alternation in accordance to Deepseek API.""" + + processed = [] + prev_role = None + + for msg in messages: + current_role = "user" if msg.type == "human" else "assistant" + + if prev_role == current_role: + if processed: + processed[-1].content += f"\n\n{msg.content}" + else: + processed.append(msg) + prev_role = current_role + + return super()._generate( + processed, stop=stop, run_manager=run_manager, **kwargs + ) diff --git a/ra_aid/env.py b/ra_aid/env.py index d41ca70..cd8b317 100644 --- a/ra_aid/env.py +++ b/ra_aid/env.py @@ -50,6 +50,9 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None: 'gemini': { 'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY', 'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL' + }, + 'deepseek': { + 'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY' } } diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 02b3c57..2821578 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -1,107 +1,184 @@ import os +from typing import Optional, Dict, Any from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI +from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner +def get_env_var(name: str, expert: bool = False) -> Optional[str]: + """Get environment variable with optional expert prefix and fallback.""" + prefix = "EXPERT_" if expert else "" + value = os.getenv(f"{prefix}{name}") -def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel: - """Initialize a language model client based on the specified provider and model. + # If expert mode and no expert value, fall back to base value + if expert and not value: + value = os.getenv(name) - Note: Environment variables must be validated before calling this function. - Use validate_environment() to ensure all required variables are set. 
+ return value + + +def create_deepseek_client( + model_name: str, + api_key: str, + base_url: str, + temperature: Optional[float] = None, + is_expert: bool = False, +) -> BaseChatModel: + """Create DeepSeek client with appropriate configuration.""" + if model_name.lower() == "deepseek-reasoner": + return ChatDeepseekReasoner( + api_key=api_key, + base_url=base_url, + temperature=0 + if is_expert + else (temperature if temperature is not None else 1), + model=model_name, + ) + + return ChatOpenAI( + api_key=api_key, + base_url=base_url, + temperature=0 if is_expert else (temperature if temperature is not None else 1), + model=model_name, + ) + + +def create_openrouter_client( + model_name: str, + api_key: str, + temperature: Optional[float] = None, + is_expert: bool = False, +) -> BaseChatModel: + """Create OpenRouter client with appropriate configuration.""" + if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower(): + return ChatDeepseekReasoner( + api_key=api_key, + base_url="https://openrouter.ai/api/v1", + temperature=0 + if is_expert + else (temperature if temperature is not None else 1), + model=model_name, + ) + + return ChatOpenAI( + api_key=api_key, + base_url="https://openrouter.ai/api/v1", + model=model_name, + **({"temperature": temperature} if temperature is not None else {}), + ) + + +def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]: + """Get provider-specific configuration.""" + configs = { + "openai": { + "api_key": get_env_var("OPENAI_API_KEY", is_expert), + "base_url": None, + }, + "anthropic": { + "api_key": get_env_var("ANTHROPIC_API_KEY", is_expert), + "base_url": None, + }, + "openrouter": { + "api_key": get_env_var("OPENROUTER_API_KEY", is_expert), + "base_url": "https://openrouter.ai/api/v1", + }, + "openai-compatible": { + "api_key": get_env_var("OPENAI_API_KEY", is_expert), + "base_url": get_env_var("OPENAI_API_BASE", is_expert), + }, + "gemini": { + "api_key": get_env_var("GEMINI_API_KEY", is_expert), + "base_url": None, + }, + "deepseek": { + "api_key": get_env_var("DEEPSEEK_API_KEY", is_expert), + "base_url": "https://api.deepseek.com", + }, + } + return configs.get(provider, {}) + + +def create_llm_client( + provider: str, + model_name: str, + temperature: Optional[float] = None, + is_expert: bool = False, +) -> BaseChatModel: + """Create a language model client with appropriate configuration. Args: - provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini') + provider: The LLM provider to use model_name: Name of the model to use - temperature: Optional temperature setting for controlling randomness (0.0-2.0). - If not specified, provider-specific defaults are used. 
+ temperature: Optional temperature setting (0.0-2.0) + is_expert: Whether this is an expert model (uses deterministic output) Returns: - BaseChatModel: Configured language model client - - Raises: - ValueError: If the provider is not supported + Configured language model client """ - if provider == "openai": + config = get_provider_config(provider, is_expert) + if not config: + raise ValueError(f"Unsupported provider: {provider}") + + # Handle temperature for expert mode + if is_expert: + temperature = 0 + + if provider == "deepseek": + return create_deepseek_client( + model_name=model_name, + api_key=config["api_key"], + base_url=config["base_url"], + temperature=temperature, + is_expert=is_expert, + ) + elif provider == "openrouter": + return create_openrouter_client( + model_name=model_name, + api_key=config["api_key"], + temperature=temperature, + is_expert=is_expert, + ) + elif provider == "openai": return ChatOpenAI( - api_key=os.getenv("OPENAI_API_KEY"), + api_key=config["api_key"], model=model_name, - **({"temperature": temperature} if temperature is not None else {}) + **({"temperature": temperature} if temperature is not None else {}), ) elif provider == "anthropic": return ChatAnthropic( - api_key=os.getenv("ANTHROPIC_API_KEY"), + api_key=config["api_key"], model_name=model_name, - **({"temperature": temperature} if temperature is not None else {}) - ) - elif provider == "openrouter": - return ChatOpenAI( - api_key=os.getenv("OPENROUTER_API_KEY"), - base_url="https://openrouter.ai/api/v1", - model=model_name, - **({"temperature": temperature} if temperature is not None else {}) + **({"temperature": temperature} if temperature is not None else {}), ) elif provider == "openai-compatible": return ChatOpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE"), + api_key=config["api_key"], + base_url=config["base_url"], temperature=temperature if temperature is not None else 0.3, model=model_name, ) elif provider == "gemini": return ChatGoogleGenerativeAI( - api_key=os.getenv("GEMINI_API_KEY"), + api_key=config["api_key"], model=model_name, - **({"temperature": temperature} if temperature is not None else {}) + **({"temperature": temperature} if temperature is not None else {}), ) else: raise ValueError(f"Unsupported provider: {provider}") -def initialize_expert_llm(provider: str = "openai", model_name: str = "o1") -> BaseChatModel: - """Initialize an expert language model client based on the specified provider and model. - Note: Environment variables must be validated before calling this function. - Use validate_environment() to ensure all required variables are set. +def initialize_llm( + provider: str, model_name: str, temperature: float | None = None +) -> BaseChatModel: + """Initialize a language model client based on the specified provider and model.""" + return create_llm_client(provider, model_name, temperature, is_expert=False) - Args: - provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible', 'gemini'). - Defaults to 'openai'. - model_name: Name of the model to use. Defaults to 'o1'. 
- Returns: - BaseChatModel: Configured expert language model client - - Raises: - ValueError: If the provider is not supported - """ - if provider == "openai": - return ChatOpenAI( - api_key=os.getenv("EXPERT_OPENAI_API_KEY"), - model=model_name, - ) - elif provider == "anthropic": - return ChatAnthropic( - api_key=os.getenv("EXPERT_ANTHROPIC_API_KEY"), - model_name=model_name, - ) - elif provider == "openrouter": - return ChatOpenAI( - api_key=os.getenv("EXPERT_OPENROUTER_API_KEY"), - base_url="https://openrouter.ai/api/v1", - model=model_name, - ) - elif provider == "openai-compatible": - return ChatOpenAI( - api_key=os.getenv("EXPERT_OPENAI_API_KEY"), - base_url=os.getenv("EXPERT_OPENAI_API_BASE"), - model=model_name, - ) - elif provider == "gemini": - return ChatGoogleGenerativeAI( - api_key=os.getenv("EXPERT_GEMINI_API_KEY"), - model=model_name, - ) - else: - raise ValueError(f"Unsupported provider: {provider}") +def initialize_expert_llm( + provider: str = "openai", model_name: str = "o1" +) -> BaseChatModel: + """Initialize an expert language model client based on the specified provider and model.""" + return create_llm_client(provider, model_name, temperature=None, is_expert=True) diff --git a/ra_aid/models_tokens.py b/ra_aid/models_tokens.py index 9f32761..9307634 100644 --- a/ra_aid/models_tokens.py +++ b/ra_aid/models_tokens.py @@ -241,6 +241,10 @@ models_tokens = { "deepseek": { "deepseek-chat": 28672, "deepseek-coder": 16384, + "deepseek-reasoner": 65536, + }, + "openrouter": { + "deepseek/deepseek-r1": 65536, }, "ernie": { "ernie-bot-turbo": 4096, diff --git a/ra_aid/provider_strategy.py b/ra_aid/provider_strategy.py index 321fad2..8ac82aa 100644 --- a/ra_aid/provider_strategy.py +++ b/ra_aid/provider_strategy.py @@ -239,6 +239,31 @@ class GeminiStrategy(ProviderStrategy): return ValidationResult(valid=len(missing) == 0, missing_vars=missing) +class DeepSeekStrategy(ProviderStrategy): + """DeepSeek provider validation strategy.""" + + def validate(self, args: Optional[Any] = None) -> ValidationResult: + """Validate DeepSeek environment variables.""" + missing = [] + + if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek': + key = os.environ.get('EXPERT_DEEPSEEK_API_KEY') + if not key or key == '': + # Try to copy from base if not set + base_key = os.environ.get('DEEPSEEK_API_KEY') + if base_key: + os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key + key = base_key + if not key: + missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set') + else: + key = os.environ.get('DEEPSEEK_API_KEY') + if not key: + missing.append('DEEPSEEK_API_KEY environment variable is not set') + + return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + + class OllamaStrategy(ProviderStrategy): """Ollama provider validation strategy.""" @@ -272,7 +297,8 @@ class ProviderFactory: 'anthropic': AnthropicStrategy(), 'openrouter': OpenRouterStrategy(), 'gemini': GeminiStrategy(), - 'ollama': OllamaStrategy() + 'ollama': OllamaStrategy(), + 'deepseek': DeepSeekStrategy() } strategy = strategies.get(provider) return strategy diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index b3243fa..dc1ea2f 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -1,123 +1,147 @@ import os -import os from typing import Dict, List, Any, Union, Optional, Set from typing_extensions import TypedDict - -class WorkLogEntry(TypedDict): - timestamp: str - event: str from rich.console import Console from rich.markdown import Markdown from 
rich.panel import Panel from langchain_core.tools import tool + +class WorkLogEntry(TypedDict): + timestamp: str + event: str + + class SnippetInfo(TypedDict): """Type definition for source code snippet information""" + filepath: str line_number: int snippet: str description: Optional[str] + console = Console() # Global memory store -_global_memory: Dict[str, Union[List[Any], Dict[int, str], Dict[int, SnippetInfo], int, Set[str], bool, str, int, List[WorkLogEntry]]] = { - 'research_notes': [], - 'plans': [], - 'tasks': {}, # Dict[int, str] - ID to task mapping - 'task_completed': False, # Flag indicating if task is complete - 'completion_message': '', # Message explaining completion - 'task_id_counter': 1, # Counter for generating unique task IDs - 'key_facts': {}, # Dict[int, str] - ID to fact mapping - 'key_fact_id_counter': 1, # Counter for generating unique fact IDs - 'key_snippets': {}, # Dict[int, SnippetInfo] - ID to snippet mapping - 'key_snippet_id_counter': 1, # Counter for generating unique snippet IDs - 'implementation_requested': False, - 'related_files': {}, # Dict[int, str] - ID to filepath mapping - 'related_file_id_counter': 1, # Counter for generating unique file IDs - 'plan_completed': False, - 'agent_depth': 0, - 'work_log': [] # List[WorkLogEntry] - Timestamped work events +_global_memory: Dict[ + str, + Union[ + List[Any], + Dict[int, str], + Dict[int, SnippetInfo], + int, + Set[str], + bool, + str, + int, + List[WorkLogEntry], + ], +] = { + "research_notes": [], + "plans": [], + "tasks": {}, # Dict[int, str] - ID to task mapping + "task_completed": False, # Flag indicating if task is complete + "completion_message": "", # Message explaining completion + "task_id_counter": 1, # Counter for generating unique task IDs + "key_facts": {}, # Dict[int, str] - ID to fact mapping + "key_fact_id_counter": 1, # Counter for generating unique fact IDs + "key_snippets": {}, # Dict[int, SnippetInfo] - ID to snippet mapping + "key_snippet_id_counter": 1, # Counter for generating unique snippet IDs + "implementation_requested": False, + "related_files": {}, # Dict[int, str] - ID to filepath mapping + "related_file_id_counter": 1, # Counter for generating unique file IDs + "plan_completed": False, + "agent_depth": 0, + "work_log": [], # List[WorkLogEntry] - Timestamped work events } + @tool("emit_research_notes") def emit_research_notes(notes: str) -> str: """Store research notes in global memory. - + Args: notes: REQUIRED The research notes to store - + Returns: The stored notes """ - _global_memory['research_notes'].append(notes) + _global_memory["research_notes"].append(notes) console.print(Panel(Markdown(notes), title="🔍 Research Notes")) return notes + @tool("emit_plan") def emit_plan(plan: str) -> str: """Store a plan step in global memory. - + Args: plan: The plan step to store (markdown format; be clear, complete, use newlines, and use as many tokens as you need) - + Returns: The stored plan """ - _global_memory['plans'].append(plan) + _global_memory["plans"].append(plan) console.print(Panel(Markdown(plan), title="📋 Plan")) log_work_event(f"Added plan step:\n\n{plan}") return plan + @tool("emit_task") def emit_task(task: str) -> str: """Store a task in global memory. 
- + Args: task: The task to store - + Returns: String confirming task storage with ID number """ # Get and increment task ID - task_id = _global_memory['task_id_counter'] - _global_memory['task_id_counter'] += 1 - + task_id = _global_memory["task_id_counter"] + _global_memory["task_id_counter"] += 1 + # Store task with ID - _global_memory['tasks'][task_id] = task - + _global_memory["tasks"][task_id] = task + console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}")) log_work_event(f"Task #{task_id} added:\n\n{task}") return f"Task #{task_id} stored." - @tool("emit_key_facts") def emit_key_facts(facts: List[str]) -> str: """Store multiple key facts about the project or current task in global memory. - + Args: facts: List of key facts to store - + Returns: List of stored fact confirmation messages """ results = [] for fact in facts: # Get and increment fact ID - fact_id = _global_memory['key_fact_id_counter'] - _global_memory['key_fact_id_counter'] += 1 - + fact_id = _global_memory["key_fact_id_counter"] + _global_memory["key_fact_id_counter"] += 1 + # Store fact with ID - _global_memory['key_facts'][fact_id] = fact - + _global_memory["key_facts"][fact_id] = fact + # Display panel with ID - console.print(Panel(Markdown(fact), title=f"💡 Key Fact #{fact_id}", border_style="bright_cyan")) - + console.print( + Panel( + Markdown(fact), + title=f"💡 Key Fact #{fact_id}", + border_style="bright_cyan", + ) + ) + # Add result message results.append(f"Stored fact #{fact_id}: {fact}") - - log_work_event(f"Stored {len(facts)} key facts.") + + log_work_event(f"Stored {len(facts)} key facts.") return "Facts stored." @@ -125,50 +149,54 @@ def emit_key_facts(facts: List[str]) -> str: def delete_key_facts(fact_ids: List[int]) -> str: """Delete multiple key facts from global memory by their IDs. Silently skips any IDs that don't exist. - + Args: fact_ids: List of fact IDs to delete - + Returns: List of success messages for deleted facts """ results = [] for fact_id in fact_ids: - if fact_id in _global_memory['key_facts']: + if fact_id in _global_memory["key_facts"]: # Delete the fact - deleted_fact = _global_memory['key_facts'].pop(fact_id) + deleted_fact = _global_memory["key_facts"].pop(fact_id) success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}" - console.print(Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")) + console.print( + Panel(Markdown(success_msg), title="Fact Deleted", border_style="green") + ) results.append(success_msg) - - log_work_event(f"Deleted facts {fact_ids}.") + + log_work_event(f"Deleted facts {fact_ids}.") return "Facts deleted." + @tool("delete_tasks") def delete_tasks(task_ids: List[int]) -> str: """Delete multiple tasks from global memory by their IDs. Silently skips any IDs that don't exist. - + Args: task_ids: List of task IDs to delete - + Returns: Confirmation message """ results = [] for task_id in task_ids: - if task_id in _global_memory['tasks']: + if task_id in _global_memory["tasks"]: # Delete the task - deleted_task = _global_memory['tasks'].pop(task_id) + deleted_task = _global_memory["tasks"].pop(task_id) success_msg = f"Successfully deleted task #{task_id}: {deleted_task}" - console.print(Panel(Markdown(success_msg), - title="Task Deleted", - border_style="green")) + console.print( + Panel(Markdown(success_msg), title="Task Deleted", border_style="green") + ) results.append(success_msg) - - log_work_event(f"Deleted tasks {task_ids}.") + + log_work_event(f"Deleted tasks {task_ids}.") return "Tasks deleted." 
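(Not part of the diff — a quick sketch of how the `@tool`-decorated memory helpers above are exercised directly; the diff itself uses the same `.invoke` pattern inside `emit_key_snippets`. The task text is made up.)

```python
# Illustrative only: the @tool-decorated helpers are LangChain structured tools,
# so direct calls go through .invoke() with a dict of arguments.
from ra_aid.tools.memory import delete_tasks, emit_task

emit_task.invoke({"task": "Wire the DeepSeek provider into ra_aid/llm.py"})  # example task text
delete_tasks.invoke({"task_ids": [1]})
```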
+ @tool("request_implementation") def request_implementation() -> str: """Request that implementation proceed after research/planning. @@ -178,46 +206,47 @@ def request_implementation() -> str: Do you need to request research subtasks first? Have you run relevant unit tests, if they exist, to get a baseline (this can be a subtask)? Do you need to crawl deeper to find all related files and symbols? - + Returns: Empty string """ - _global_memory['implementation_requested'] = True + _global_memory["implementation_requested"] = True console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0)) log_work_event("Implementation requested.") return "" - @tool("emit_key_snippets") def emit_key_snippets(snippets: List[SnippetInfo]) -> str: """Store multiple key source code snippets in global memory. Automatically adds the filepaths of the snippets to related files. - + This is for **existing**, or **just-written** files, not for things to be created in the future. - + Args: snippets: REQUIRED List of snippet information dictionaries containing: - filepath: Path to the source file - - line_number: Line number where the snippet starts + - line_number: Line number where the snippet starts - snippet: The source code snippet text - description: Optional description of the significance - + Returns: List of stored snippet confirmation messages """ # First collect unique filepaths to add as related files - emit_related_files.invoke({"files": [snippet_info['filepath'] for snippet_info in snippets]}) + emit_related_files.invoke( + {"files": [snippet_info["filepath"] for snippet_info in snippets]} + ) results = [] for snippet_info in snippets: - # Get and increment snippet ID - snippet_id = _global_memory['key_snippet_id_counter'] - _global_memory['key_snippet_id_counter'] += 1 - + # Get and increment snippet ID + snippet_id = _global_memory["key_snippet_id_counter"] + _global_memory["key_snippet_id_counter"] += 1 + # Store snippet info - _global_memory['key_snippets'][snippet_id] = snippet_info - + _global_memory["key_snippets"][snippet_id] = snippet_info + # Format display text as markdown display_text = [ f"**Source Location**:", @@ -226,153 +255,170 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str: "", # Empty line before code block "**Code**:", "```python", - snippet_info['snippet'].rstrip(), # Remove trailing whitespace - "```" + snippet_info["snippet"].rstrip(), # Remove trailing whitespace + "```", ] - if snippet_info['description']: - display_text.extend(["", "**Description**:", snippet_info['description']]) - + if snippet_info["description"]: + display_text.extend(["", "**Description**:", snippet_info["description"]]) + # Display panel - console.print(Panel(Markdown("\n".join(display_text)), - title=f"📝 Key Snippet #{snippet_id}", - border_style="bright_cyan")) - + console.print( + Panel( + Markdown("\n".join(display_text)), + title=f"📝 Key Snippet #{snippet_id}", + border_style="bright_cyan", + ) + ) + results.append(f"Stored snippet #{snippet_id}") - - log_work_event(f"Stored {len(snippets)} code snippets.") + + log_work_event(f"Stored {len(snippets)} code snippets.") return "Snippets stored." -@tool("delete_key_snippets") + +@tool("delete_key_snippets") def delete_key_snippets(snippet_ids: List[int]) -> str: """Delete multiple key snippets from global memory by their IDs. Silently skips any IDs that don't exist. 
- + Args: snippet_ids: List of snippet IDs to delete - + Returns: List of success messages for deleted snippets """ results = [] for snippet_id in snippet_ids: - if snippet_id in _global_memory['key_snippets']: + if snippet_id in _global_memory["key_snippets"]: # Delete the snippet - deleted_snippet = _global_memory['key_snippets'].pop(snippet_id) + deleted_snippet = _global_memory["key_snippets"].pop(snippet_id) success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}" - console.print(Panel(Markdown(success_msg), - title="Snippet Deleted", - border_style="green")) + console.print( + Panel( + Markdown(success_msg), title="Snippet Deleted", border_style="green" + ) + ) results.append(success_msg) - - log_work_event(f"Deleted snippets {snippet_ids}.") + + log_work_event(f"Deleted snippets {snippet_ids}.") return "Snippets deleted." + @tool("swap_task_order") def swap_task_order(id1: int, id2: int) -> str: """Swap the order of two tasks in global memory by their IDs. - + Args: id1: First task ID id2: Second task ID - + Returns: Success or error message depending on outcome """ # Validate IDs are different if id1 == id2: return "Cannot swap task with itself" - + # Validate both IDs exist - if id1 not in _global_memory['tasks'] or id2 not in _global_memory['tasks']: + if id1 not in _global_memory["tasks"] or id2 not in _global_memory["tasks"]: return "Invalid task ID(s)" - + # Swap the tasks - _global_memory['tasks'][id1], _global_memory['tasks'][id2] = \ - _global_memory['tasks'][id2], _global_memory['tasks'][id1] - + _global_memory["tasks"][id1], _global_memory["tasks"][id2] = ( + _global_memory["tasks"][id2], + _global_memory["tasks"][id1], + ) + # Display what was swapped - console.print(Panel( - Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"), - title="🔄 Tasks Reordered", - border_style="green" - )) - + console.print( + Panel( + Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"), + title="🔄 Tasks Reordered", + border_style="green", + ) + ) + return "Tasks swapped." + @tool("one_shot_completed") def one_shot_completed(message: str) -> str: """Signal that a one-shot task has been completed and execution should stop. Only call this if you have already **fully** completed the original request. - + Args: message: Completion message to display """ - if _global_memory.get('implementation_requested', False): + if _global_memory.get("implementation_requested", False): return "Cannot complete in one shot - implementation was requested" - - _global_memory['task_completed'] = True - _global_memory['completion_message'] = message + + _global_memory["task_completed"] = True + _global_memory["completion_message"] = message console.print(Panel(Markdown(message), title="✅ Task Completed")) log_work_event(f"Task completed\n\n{message}") return "Completion noted." + @tool("task_completed") def task_completed(message: str) -> str: """Mark the current task as completed with a completion message. - + Args: message: Message explaining how/why the task is complete - + Returns: The completion message """ - _global_memory['task_completed'] = True - _global_memory['completion_message'] = message + _global_memory["task_completed"] = True + _global_memory["completion_message"] = message console.print(Panel(Markdown(message), title="✅ Task Completed")) return "Completion noted." + @tool("plan_implementation_completed") def plan_implementation_completed(message: str) -> str: """Mark the entire implementation plan as completed. 
- + Args: message: Message explaining how the implementation plan was completed - + Returns: Confirmation message """ - _global_memory['plan_completed'] = True - _global_memory['completion_message'] = message - _global_memory['tasks'].clear() # Clear task list when plan is completed - _global_memory['task_id_counter'] = 1 + _global_memory["plan_completed"] = True + _global_memory["completion_message"] = message + _global_memory["tasks"].clear() # Clear task list when plan is completed + _global_memory["task_id_counter"] = 1 console.print(Panel(Markdown(message), title="✅ Plan Executed")) log_work_event(f"Plan execution completed:\n\n{message}") return "Plan completion noted and task list cleared." + def get_related_files() -> List[str]: """Get the current list of related files. - + Returns: List of formatted strings in the format 'ID#X path/to/file.py' """ - files = _global_memory['related_files'] + files = _global_memory["related_files"] return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())] + @tool("emit_related_files") def emit_related_files(files: List[str]) -> str: """Store multiple related files that tools should work with. - + Args: files: List of file paths to add - + Returns: Formatted string containing file IDs and paths for all processed files """ results = [] added_files = [] invalid_paths = [] - + # Process files for file in files: # First check if path exists @@ -380,111 +426,115 @@ def emit_related_files(files: List[str]) -> str: invalid_paths.append(file) results.append(f"Error: Path '{file}' does not exist") continue - + # Then check if it's a directory if os.path.isdir(file): invalid_paths.append(file) results.append(f"Error: Path '{file}' is a directory, not a file") continue - + # Finally validate it's a regular file if not os.path.isfile(file): invalid_paths.append(file) results.append(f"Error: Path '{file}' exists but is not a regular file") continue - + # Check if file path already exists in values existing_id = None - for fid, fpath in _global_memory['related_files'].items(): + for fid, fpath in _global_memory["related_files"].items(): if fpath == file: existing_id = fid break - + if existing_id is not None: # File exists, use existing ID results.append(f"File ID #{existing_id}: {file}") else: # New file, assign new ID - file_id = _global_memory['related_file_id_counter'] - _global_memory['related_file_id_counter'] += 1 - + file_id = _global_memory["related_file_id_counter"] + _global_memory["related_file_id_counter"] += 1 + # Store file with ID - _global_memory['related_files'][file_id] = file + _global_memory["related_files"][file_id] = file added_files.append((file_id, file)) results.append(f"File ID #{file_id}: {file}") - + # Rich output - single consolidated panel if added_files: - files_added_md = '\n'.join(f"- `{file}`" for id, file in added_files) + files_added_md = "\n".join(f"- `{file}`" for id, file in added_files) md_content = f"**Files Noted:**\n{files_added_md}" - console.print(Panel(Markdown(md_content), - title="📁 Related Files Noted", - border_style="green")) - - return '\n'.join(results) + console.print( + Panel( + Markdown(md_content), + title="📁 Related Files Noted", + border_style="green", + ) + ) + + return "\n".join(results) def log_work_event(event: str) -> str: """Add timestamped entry to work log. - + Internal function used to track major events during agent execution. Each entry is stored with an ISO format timestamp. 
- + Args: event: Description of the event to log - + Returns: Confirmation message - + Note: Entries can be retrieved with get_work_log() as markdown formatted text. """ from datetime import datetime - entry = WorkLogEntry( - timestamp=datetime.now().isoformat(), - event=event - ) - _global_memory['work_log'].append(entry) + + entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event) + _global_memory["work_log"].append(entry) return f"Event logged: {event}" def get_work_log() -> str: """Return formatted markdown of work log entries. - + Returns: Markdown formatted text with timestamps as headings and events as content, or 'No work log entries' if the log is empty. - + Example: ## 2024-12-23T11:39:10 Task #1 added: Create login form """ - if not _global_memory['work_log']: + if not _global_memory["work_log"]: return "No work log entries" - + entries = [] - for entry in _global_memory['work_log']: - entries.extend([ - f"## {entry['timestamp']}", - "", - entry['event'], - "" # Blank line between entries - ]) - + for entry in _global_memory["work_log"]: + entries.extend( + [ + f"## {entry['timestamp']}", + "", + entry["event"], + "", # Blank line between entries + ] + ) + return "\n".join(entries).rstrip() # Remove trailing newline def reset_work_log() -> str: """Clear the work log. - - Returns: + + Returns: Confirmation message - + Note: This permanently removes all work log entries. The operation cannot be undone. """ - _global_memory['work_log'].clear() + _global_memory["work_log"].clear() return "Work log cleared" @@ -492,37 +542,44 @@ def reset_work_log() -> str: def deregister_related_files(file_ids: List[int]) -> str: """Delete multiple related files from global memory by their IDs. Silently skips any IDs that don't exist. - + Args: file_ids: List of file IDs to delete - + Returns: Success message string """ results = [] for file_id in file_ids: - if file_id in _global_memory['related_files']: + if file_id in _global_memory["related_files"]: # Delete the file reference - deleted_file = _global_memory['related_files'].pop(file_id) - success_msg = f"Successfully removed related file #{file_id}: {deleted_file}" - console.print(Panel(Markdown(success_msg), - title="File Reference Removed", - border_style="green")) + deleted_file = _global_memory["related_files"].pop(file_id) + success_msg = ( + f"Successfully removed related file #{file_id}: {deleted_file}" + ) + console.print( + Panel( + Markdown(success_msg), + title="File Reference Removed", + border_style="green", + ) + ) results.append(success_msg) - + return "File references removed." + def get_memory_value(key: str) -> str: """Get a value from global memory. 
- + Different memory types return different formats: - key_facts: Returns numbered list of facts in format '#ID: fact' - key_snippets: Returns formatted snippets with file path, line number and content - All other types: Returns newline-separated list of values - + Args: key: The key to get from memory - + Returns: String representation of the memory values: - For key_facts: '#ID: fact' format, one per line @@ -530,23 +587,25 @@ def get_memory_value(key: str) -> str: - For other types: One value per line """ values = _global_memory.get(key, []) - - if key == 'key_facts': + + if key == "key_facts": # For empty dict, return empty string if not values: return "" # Sort by ID for consistent output and format as markdown sections facts = [] for k, v in sorted(values.items()): - facts.extend([ - f"## 🔑 Key Fact #{k}", - "", # Empty line for better markdown spacing - v, - "" # Empty line between facts - ]) + facts.extend( + [ + f"## 🔑 Key Fact #{k}", + "", # Empty line for better markdown spacing + v, + "", # Empty line between facts + ] + ) return "\n".join(facts).rstrip() # Remove trailing newline - - if key == 'key_snippets': + + if key == "key_snippets": if not values: return "" # Format each snippet with file info and content using markdown @@ -555,26 +614,25 @@ def get_memory_value(key: str) -> str: snippet_text = [ f"## 📝 Code Snippet #{k}", "", # Empty line for better markdown spacing - f"**Source Location**:", + "**Source Location**:", f"- File: `{v['filepath']}`", f"- Line: `{v['line_number']}`", "", # Empty line before code block "**Code**:", "```python", - v['snippet'].rstrip(), # Remove trailing whitespace - "```" + v["snippet"].rstrip(), # Remove trailing whitespace + "```", ] - if v['description']: + if v["description"]: # Add empty line and description - snippet_text.extend(["", "**Description**:", v['description']]) + snippet_text.extend(["", "**Description**:", v["description"]]) snippets.append("\n".join(snippet_text)) return "\n\n".join(snippets) - - if key == 'work_log': + + if key == "work_log": if not values: return "" - entries = [f"## {entry['timestamp']}\n{entry['event']}" - for entry in values] + entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values] return "\n\n".join(entries) # For other types (lists), join with newlines diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 4a5efad..403b2d1 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -9,7 +9,13 @@ from dataclasses import dataclass from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.env import validate_environment -from ra_aid.llm import initialize_llm, initialize_expert_llm +from ra_aid.llm import ( + initialize_llm, + initialize_expert_llm, + get_env_var, + get_provider_config, + create_llm_client +) @pytest.fixture def clean_env(monkeypatch): @@ -39,7 +45,8 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch): mock_openai.assert_called_once_with( api_key="test-key", - model="o1" + model="o1", + temperature=0 ) def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): @@ -49,7 +56,8 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): mock_openai.assert_called_once_with( api_key="test-key", - model="gpt-4-preview" + model="gpt-4-preview", + temperature=0 ) def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch): @@ -59,7 +67,8 @@ def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch): mock_gemini.assert_called_once_with( api_key="test-key", 
- model="gemini-2.0-flash-thinking-exp-1219" + model="gemini-2.0-flash-thinking-exp-1219", + temperature=0 ) def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch): @@ -69,7 +78,8 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch): mock_anthropic.assert_called_once_with( api_key="test-key", - model_name="claude-3" + model_name="claude-3", + temperature=0 ) def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch): @@ -80,7 +90,8 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch): mock_openai.assert_called_once_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", - model="models/mistral-large" + model="models/mistral-large", + temperature=0 ) def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch): @@ -92,7 +103,8 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch mock_openai.assert_called_once_with( api_key="test-key", base_url="http://test-url", - model="local-model" + model="local-model", + temperature=0 ) def test_initialize_expert_unsupported_provider(clean_env): @@ -311,45 +323,49 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, m model="gemini-2.0-flash-thinking-exp-1219" ) +@dataclass +class Args: + """Test arguments class.""" + provider: str + expert_provider: str + model: str = None + expert_model: str = None + def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch): """Test environment variable precedence and fallback.""" - from ra_aid.env import validate_environment - from dataclasses import dataclass + # Test get_env_var helper with fallback + monkeypatch.setenv("TEST_KEY", "base-value") + monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value") - @dataclass - class Args: - provider: str - expert_provider: str - model: str = None - expert_model: str = None - - # Test expert mode with explicit key - # Set up base environment first - monkeypatch.setenv("OPENAI_API_KEY", "base-key") + assert get_env_var("TEST_KEY") == "base-value" + assert get_env_var("TEST_KEY", expert=True) == "expert-value" + + # Test fallback when expert value not set + monkeypatch.delenv("EXPERT_TEST_KEY", raising=False) + assert get_env_var("TEST_KEY", expert=True) == "base-value" + + # Test provider config monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key") - monkeypatch.setenv("TAVILY_API_KEY", "tavily-key") - monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") - args = Args(provider="openai", expert_provider="openai") - expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args) - assert expert_enabled - assert not expert_missing - assert web_enabled - assert not web_missing - - llm = initialize_expert_llm() + config = get_provider_config("openai", is_expert=True) + assert config["api_key"] == "expert-key" + + # Test LLM client creation with expert mode + llm = create_llm_client("openai", "o1", is_expert=True) mock_openai.assert_called_with( api_key="expert-key", - model="o1" + model="o1", + temperature=0 ) - # Test empty key validation + # Test environment validation monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Remove fallback - monkeypatch.delenv("TAVILY_API_KEY", raising=False) # Remove web research - monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") # Add for provider validation - monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") # Add for provider validation - 
monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Add model for provider validation - args = Args(provider="anthropic", expert_provider="openai") # Change base provider to avoid validation error + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("TAVILY_API_KEY", raising=False) + monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") + monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") + monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") + + args = Args(provider="anthropic", expert_provider="openai") expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args) assert not expert_enabled assert expert_missing @@ -372,3 +388,122 @@ def mock_gemini(): with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock: mock.return_value = Mock(spec=ChatGoogleGenerativeAI) yield mock + +@pytest.fixture +def mock_deepseek_reasoner(): + """Mock ChatDeepseekReasoner for testing DeepSeek provider initialization.""" + with patch('ra_aid.llm.ChatDeepseekReasoner') as mock: + mock.return_value = Mock() + yield mock + +def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + """Test DeepSeek provider initialization with different models.""" + monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key") + + # Test with reasoner model + model = initialize_llm("deepseek", "deepseek-reasoner") + mock_deepseek_reasoner.assert_called_with( + api_key="test-key", + base_url="https://api.deepseek.com", + temperature=1, + model="deepseek-reasoner" + ) + + # Test with non-reasoner model + model = initialize_llm("deepseek", "deepseek-chat") + mock_openai.assert_called_with( + api_key="test-key", + base_url="https://api.deepseek.com", + temperature=1, + model="deepseek-chat" + ) + +def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + """Test expert DeepSeek provider initialization.""" + monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key") + + # Test with reasoner model + model = initialize_expert_llm("deepseek", "deepseek-reasoner") + mock_deepseek_reasoner.assert_called_with( + api_key="test-key", + base_url="https://api.deepseek.com", + temperature=0, + model="deepseek-reasoner" + ) + + # Test with non-reasoner model + model = initialize_expert_llm("deepseek", "deepseek-chat") + mock_openai.assert_called_with( + api_key="test-key", + base_url="https://api.deepseek.com", + temperature=0, + model="deepseek-chat" + ) + +def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + """Test OpenRouter DeepSeek model initialization.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "test-key") + + # Test with DeepSeek R1 model + model = initialize_llm("openrouter", "deepseek/deepseek-r1") + mock_deepseek_reasoner.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + temperature=1, + model="deepseek/deepseek-r1" + ) + + # Test with non-DeepSeek model + model = initialize_llm("openrouter", "mistral/mistral-large") + mock_openai.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + model="mistral/mistral-large" + ) + +def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + """Test expert OpenRouter DeepSeek model initialization.""" + monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key") + + # Test with DeepSeek R1 model via create_llm_client + model = create_llm_client("openrouter", "deepseek/deepseek-r1", 
is_expert=True) + mock_deepseek_reasoner.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + temperature=0, + model="deepseek/deepseek-r1" + ) + + # Test with non-DeepSeek model + model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True) + mock_openai.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + model="mistral/mistral-large", + temperature=0 + ) + +def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch): + """Test DeepSeek environment variable fallback behavior.""" + # Test environment variable helper with fallback + monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key") + assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key" + + # Test provider config with fallback + config = get_provider_config("deepseek", is_expert=True) + assert config["api_key"] == "base-key" + assert config["base_url"] == "https://api.deepseek.com" + + # Test with expert key + monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key") + config = get_provider_config("deepseek", is_expert=True) + assert config["api_key"] == "expert-key" + + # Test client creation with expert key + model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True) + mock_deepseek_reasoner.assert_called_with( + api_key="expert-key", + base_url="https://api.deepseek.com", + temperature=0, + model="deepseek-reasoner" + )
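
For reviewers trying the branch locally, here is a minimal usage sketch of the new provider wiring in `ra_aid/llm.py`. It is not part of the diff; the key values are placeholders and would normally come from the shell environment.

```python
import os

from ra_aid.llm import create_llm_client, initialize_expert_llm, initialize_llm

# Placeholder keys -- set these in your shell for real runs.
os.environ["DEEPSEEK_API_KEY"] = "your_deepseek_api_key"
os.environ["OPENROUTER_API_KEY"] = "your_openrouter_api_key"

# Direct DeepSeek provider: "deepseek-reasoner" is routed to ChatDeepseekReasoner.
agent_llm = initialize_llm("deepseek", "deepseek-reasoner")

# Expert path: temperature is forced to 0, and because EXPERT_DEEPSEEK_API_KEY is
# unset here, get_env_var() falls back to the base DEEPSEEK_API_KEY.
expert_llm = initialize_expert_llm("deepseek", "deepseek-reasoner")

# DeepSeek R1 served through OpenRouter also resolves to ChatDeepseekReasoner.
r1_via_openrouter = create_llm_client("openrouter", "deepseek/deepseek-r1")
```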
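
A sketch of what the two `ChatDeepseekReasoner` overrides in `ra_aid/chat_models/deepseek_chat.py` mean for callers. The message contents are made up, and the final `invoke` call is left commented out because it would hit the live DeepSeek API.

```python
from langchain_core.messages import AIMessage, HumanMessage

from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner

llm = ChatDeepseekReasoner(
    api_key="your_deepseek_api_key",      # placeholder
    base_url="https://api.deepseek.com",
    model="deepseek-reasoner",
    temperature=1,
)

# invocation_params() strips the sampling knobs the R1 endpoint rejects
# (temperature, top_p, presence_penalty, frequency_penalty) before the
# request payload is built.

# _generate() enforces user/assistant alternation: consecutive same-role
# messages are merged into one turn, joined by a blank line.
messages = [
    HumanMessage(content="Review ra_aid/llm.py."),
    HumanMessage(content="Focus on the DeepSeek code path."),  # merged into the previous turn
    AIMessage(content="Summary of the DeepSeek code path ..."),
    HumanMessage(content="Now check the OpenRouter routing."),
]
# llm.invoke(messages)  # commented out: would call the live DeepSeek API
```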
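
Finally, a sketch of the expert-key fallback implemented by `DeepSeekStrategy.validate()` in `ra_aid/provider_strategy.py`, assuming only the base `DEEPSEEK_API_KEY` is set (placeholder value below).

```python
import os
from types import SimpleNamespace

from ra_aid.provider_strategy import DeepSeekStrategy

os.environ["DEEPSEEK_API_KEY"] = "your_deepseek_api_key"   # placeholder
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

# Validation for --expert-provider deepseek copies the base key into
# EXPERT_DEEPSEEK_API_KEY when no expert-specific key is present.
args = SimpleNamespace(provider="deepseek", expert_provider="deepseek")
result = DeepSeekStrategy().validate(args)

assert result.valid
assert os.environ["EXPERT_DEEPSEEK_API_KEY"] == "your_deepseek_api_key"
```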