Add Deepseek Provider Support and Custom Deepseek Reasoner Chat Model (#50)
* chore: Add DeepSeek provider environment variable support in env.py
* feat: Add DeepSeek provider validation strategy in provider_strategy.py
* feat: Add support for DEEPSEEK provider in initialize_llm function
* feat: Create ChatDeepseekReasoner for custom handling of R1 models
* feat: Configure custom OpenAI client for DeepSeek API integration
* chore: Remove unused json import from deepseek_chat.py
* refactor: Simplify invocation_params and update acompletion_with_retry method
* feat: Override _generate to ensure message alternation in DeepseekReasoner
* feat: Add support for ChatDeepseekReasoner in LLM initialization
* feat: Use custom ChatDeepseekReasoner for DeepSeek models in OpenRouter
* fix: Remove redundant condition for DeepSeek model initialization
* feat: Add DeepSeek support for expert model initialization in llm.py
* feat: Add DeepSeek model handling for OpenRouter in expert LLM initialization
* fix: Update model name checks for DeepSeek and OpenRouter providers
* refactor: Extract common logic for LLM initialization into reusable methods
* test: Add unit tests for DeepSeek and OpenRouter functionality
* test: Refactor tests to match updated LLM initialization and helpers
* fix: Import missing helper functions to resolve NameError in tests
* fix: Resolve NameError and improve environment variable fallback logic

feat(readme): add DeepSeek API key requirements to the environment variable documentation
feat(main.py): include DeepSeek as a supported provider in argument parsing
feat(deepseek_chat.py): implement ChatDeepseekReasoner class for handling DeepSeek reasoning models
feat(llm.py): add DeepSeek client creation logic to support DeepSeek models
feat(models_tokens.py): define token limits for DeepSeek models
fix(provider_strategy.py): correct validation logic for DeepSeek environment variables
chore(memory.py): refactor global memory structure for better readability and maintainability
Parent: 7a68de2d06
Commit: 686ab42f88

README.md (15 lines changed)
@@ -293,6 +293,7 @@ RA.Aid supports multiple providers through environment variables:
- `ANTHROPIC_API_KEY`: Required for the default Anthropic provider
- `OPENAI_API_KEY`: Required for OpenAI provider
- `OPENROUTER_API_KEY`: Required for OpenRouter provider
- `DEEPSEEK_API_KEY`: Required for DeepSeek provider
- `OPENAI_API_BASE`: Required for OpenAI-compatible providers along with `OPENAI_API_KEY`
- `GEMINI_API_KEY`: Required for Gemini provider
@@ -302,6 +303,7 @@ Expert Tool Environment Variables:
- `EXPERT_OPENROUTER_API_KEY`: API key for expert tool using OpenRouter provider
- `EXPERT_OPENAI_API_BASE`: Base URL for expert tool using OpenAI-compatible provider
- `EXPERT_GEMINI_API_KEY`: API key for expert tool using Gemini provider
- `EXPERT_DEEPSEEK_API_KEY`: API key for expert tool using DeepSeek provider

You can set these permanently in your shell's configuration file (e.g., `~/.bashrc` or `~/.zshrc`):
@@ -343,6 +345,15 @@ export GEMINI_API_KEY=your_api_key_here
ra-aid -m "Your task" --provider openrouter --model mistralai/mistral-large-2411
```

4. **Using DeepSeek**
```bash
# Direct DeepSeek provider (requires DEEPSEEK_API_KEY)
ra-aid -m "Your task" --provider deepseek --model deepseek-reasoner

# DeepSeek via OpenRouter
ra-aid -m "Your task" --provider openrouter --model deepseek/deepseek-r1
```

4. **Configuring Expert Provider**

The expert tool is used by the agent for complex logic and debugging tasks. It can be configured to use different providers (OpenAI, Anthropic, OpenRouter, Gemini, openai-compatible) using the --expert-provider flag along with the corresponding EXPERT_*API_KEY environment variables.
@@ -356,6 +367,10 @@ export GEMINI_API_KEY=your_api_key_here
export OPENROUTER_API_KEY=your_openrouter_api_key
ra-aid -m "Your task" --expert-provider openrouter --expert-model mistralai/mistral-large-2411

# Use DeepSeek for expert tool
export DEEPSEEK_API_KEY=your_deepseek_api_key
ra-aid -m "Your task" --expert-provider deepseek --expert-model deepseek-reasoner

# Use default OpenAI for expert tool
export EXPERT_OPENAI_API_KEY=your_openai_api_key
ra-aid -m "Your task" --expert-provider openai --expert-model o1
main.py

@@ -39,6 +39,7 @@ def parse_arguments(args=None):
    "openai",
    "openrouter",
    "openai-compatible",
    "deepseek",
    "gemini",
]
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
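A minimal sketch of what the new provider choice means for the CLI parser (illustrative only; the import path and attribute names are guesses based on the commit message and the test fixtures, not shown in this hunk):

```python
# Hypothetical usage of the updated argument parser
from main import parse_arguments  # module path assumed from "feat(main.py)"

args = parse_arguments(
    ["-m", "Your task", "--provider", "deepseek", "--model", "deepseek-reasoner"]
)
assert args.provider == "deepseek"  # "deepseek" is now an accepted choice
```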
ra_aid/chat_models/deepseek_chat.py (new file)

@@ -0,0 +1,52 @@
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_openai import ChatOpenAI
from typing import Any, List, Optional, Dict


# Docs: https://api-docs.deepseek.com/guides/reasoning_model
class ChatDeepseekReasoner(ChatOpenAI):
    """ChatDeepseekReasoner with custom overrides for R1/reasoner models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def invocation_params(
        self, options: Optional[Dict[str, Any]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        params = super().invocation_params(options, **kwargs)

        # Remove unsupported params for R1 models
        params.pop("temperature", None)
        params.pop("top_p", None)
        params.pop("presence_penalty", None)
        params.pop("frequency_penalty", None)

        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override _generate to ensure message alternation in accordance with the DeepSeek API."""

        processed = []
        prev_role = None

        for msg in messages:
            current_role = "user" if msg.type == "human" else "assistant"

            if prev_role == current_role:
                if processed:
                    processed[-1].content += f"\n\n{msg.content}"
            else:
                processed.append(msg)
            prev_role = current_role

        return super()._generate(
            processed, stop=stop, run_manager=run_manager, **kwargs
        )
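A minimal usage sketch of the new class (illustrative only; the key is a placeholder, a real request needs a valid DEEPSEEK_API_KEY, and the base URL mirrors the one configured for the provider in llm.py below):

```python
import os
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner

# Direct instantiation; initialize_llm() below does this for you when the
# provider is "deepseek" and the model is "deepseek-reasoner".
llm = ChatDeepseekReasoner(
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url="https://api.deepseek.com",
    model="deepseek-reasoner",
    temperature=1,  # stripped out again by invocation_params for R1 models
)

# Consecutive human messages are merged by _generate before the request is sent.
print(llm.invoke("Explain BFS vs DFS in two sentences.").content)
```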
ra_aid/env.py

@@ -50,6 +50,9 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None:
        'gemini': {
            'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY',
            'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL'
        },
        'deepseek': {
            'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY'
        }
    }
ra_aid/llm.py (217 lines changed)
@@ -1,107 +1,184 @@
import os
from typing import Optional, Dict, Any
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner


def get_env_var(name: str, expert: bool = False) -> Optional[str]:
    """Get environment variable with optional expert prefix and fallback."""
    prefix = "EXPERT_" if expert else ""
    value = os.getenv(f"{prefix}{name}")

    # If expert mode and no expert value, fall back to base value
    if expert and not value:
        value = os.getenv(name)

    return value


def create_deepseek_client(
    model_name: str,
    api_key: str,
    base_url: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create DeepSeek client with appropriate configuration."""
    if model_name.lower() == "deepseek-reasoner":
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url=base_url,
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url=base_url,
        temperature=0 if is_expert else (temperature if temperature is not None else 1),
        model=model_name,
    )


def create_openrouter_client(
    model_name: str,
    api_key: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create OpenRouter client with appropriate configuration."""
    if model_name.startswith("deepseek/") and "deepseek-r1" in model_name.lower():
        return ChatDeepseekReasoner(
            api_key=api_key,
            base_url="https://openrouter.ai/api/v1",
            temperature=0
            if is_expert
            else (temperature if temperature is not None else 1),
            model=model_name,
        )

    return ChatOpenAI(
        api_key=api_key,
        base_url="https://openrouter.ai/api/v1",
        model=model_name,
        **({"temperature": temperature} if temperature is not None else {}),
    )


def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any]:
    """Get provider-specific configuration."""
    configs = {
        "openai": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": None,
        },
        "anthropic": {
            "api_key": get_env_var("ANTHROPIC_API_KEY", is_expert),
            "base_url": None,
        },
        "openrouter": {
            "api_key": get_env_var("OPENROUTER_API_KEY", is_expert),
            "base_url": "https://openrouter.ai/api/v1",
        },
        "openai-compatible": {
            "api_key": get_env_var("OPENAI_API_KEY", is_expert),
            "base_url": get_env_var("OPENAI_API_BASE", is_expert),
        },
        "gemini": {
            "api_key": get_env_var("GEMINI_API_KEY", is_expert),
            "base_url": None,
        },
        "deepseek": {
            "api_key": get_env_var("DEEPSEEK_API_KEY", is_expert),
            "base_url": "https://api.deepseek.com",
        },
    }
    return configs.get(provider, {})


def create_llm_client(
    provider: str,
    model_name: str,
    temperature: Optional[float] = None,
    is_expert: bool = False,
) -> BaseChatModel:
    """Create a language model client with appropriate configuration.

    Args:
        provider: The LLM provider to use
        model_name: Name of the model to use
        temperature: Optional temperature setting (0.0-2.0)
        is_expert: Whether this is an expert model (uses deterministic output)

    Returns:
        Configured language model client
    """
    config = get_provider_config(provider, is_expert)
    if not config:
        raise ValueError(f"Unsupported provider: {provider}")

    # Handle temperature for expert mode
    if is_expert:
        temperature = 0

    if provider == "deepseek":
        return create_deepseek_client(
            model_name=model_name,
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openrouter":
        return create_openrouter_client(
            model_name=model_name,
            api_key=config["api_key"],
            temperature=temperature,
            is_expert=is_expert,
        )
    elif provider == "openai":
        return ChatOpenAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    elif provider == "anthropic":
        return ChatAnthropic(
            api_key=config["api_key"],
            model_name=model_name,
            **({"temperature": temperature} if temperature is not None else {})
        )
    elif provider == "openai-compatible":
        return ChatOpenAI(
            api_key=config["api_key"],
            base_url=config["base_url"],
            temperature=temperature if temperature is not None else 0.3,
            model=model_name,
        )
    elif provider == "gemini":
        return ChatGoogleGenerativeAI(
            api_key=config["api_key"],
            model=model_name,
            **({"temperature": temperature} if temperature is not None else {}),
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def initialize_llm(
    provider: str, model_name: str, temperature: float | None = None
) -> BaseChatModel:
    """Initialize a language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature, is_expert=False)


def initialize_expert_llm(
    provider: str = "openai", model_name: str = "o1"
) -> BaseChatModel:
    """Initialize an expert language model client based on the specified provider and model."""
    return create_llm_client(provider, model_name, temperature=None, is_expert=True)
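A short sketch of how these helpers compose (illustrative; the key value is a placeholder and assumes the environment variables described in the README above are otherwise unset):

```python
import os
from ra_aid.llm import initialize_llm, initialize_expert_llm, get_env_var

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"

# Direct DeepSeek: the reasoner model is routed to ChatDeepseekReasoner;
# other DeepSeek models fall back to a plain ChatOpenAI client.
chat = initialize_llm("deepseek", "deepseek-reasoner", temperature=1)

# Expert mode: temperature is forced to 0 and EXPERT_DEEPSEEK_API_KEY is used,
# falling back to DEEPSEEK_API_KEY when the expert key is unset.
expert = initialize_expert_llm("deepseek", "deepseek-reasoner")

# The fallback itself lives in get_env_var:
assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "sk-placeholder"
```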
models_tokens.py

@@ -241,6 +241,10 @@ models_tokens = {
    "deepseek": {
        "deepseek-chat": 28672,
        "deepseek-coder": 16384,
        "deepseek-reasoner": 65536,
    },
    "openrouter": {
        "deepseek/deepseek-r1": 65536,
    },
    "ernie": {
        "ernie-bot-turbo": 4096,
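A tiny lookup sketch for the new limits (illustrative; the exact module path for models_tokens inside the package is an assumption here, only the dictionary contents come from the hunk above):

```python
from ra_aid.models_tokens import models_tokens  # import path assumed

assert models_tokens["deepseek"]["deepseek-reasoner"] == 65536
assert models_tokens["openrouter"]["deepseek/deepseek-r1"] == 65536
```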
provider_strategy.py

@@ -239,6 +239,31 @@ class GeminiStrategy(ProviderStrategy):
        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class DeepSeekStrategy(ProviderStrategy):
    """DeepSeek provider validation strategy."""

    def validate(self, args: Optional[Any] = None) -> ValidationResult:
        """Validate DeepSeek environment variables."""
        missing = []

        if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek':
            key = os.environ.get('EXPERT_DEEPSEEK_API_KEY')
            if not key or key == '':
                # Try to copy from base if not set
                base_key = os.environ.get('DEEPSEEK_API_KEY')
                if base_key:
                    os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key
                    key = base_key
            if not key:
                missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set')
        else:
            key = os.environ.get('DEEPSEEK_API_KEY')
            if not key:
                missing.append('DEEPSEEK_API_KEY environment variable is not set')

        return ValidationResult(valid=len(missing) == 0, missing_vars=missing)


class OllamaStrategy(ProviderStrategy):
    """Ollama provider validation strategy."""
@@ -272,7 +297,8 @@ class ProviderFactory:
            'anthropic': AnthropicStrategy(),
            'openrouter': OpenRouterStrategy(),
            'gemini': GeminiStrategy(),
            'ollama': OllamaStrategy(),
            'deepseek': DeepSeekStrategy()
        }
        strategy = strategies.get(provider)
        return strategy
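A hedged sketch of the expert-key fallback the new strategy implements (illustrative; the module path is assumed, and the args object is a stand-in for the parsed CLI arguments):

```python
import os
from types import SimpleNamespace
from ra_aid.provider_strategy import DeepSeekStrategy  # module path assumed

os.environ["DEEPSEEK_API_KEY"] = "sk-placeholder"
os.environ.pop("EXPERT_DEEPSEEK_API_KEY", None)

strategy = DeepSeekStrategy()
# Expert validation copies the base key into EXPERT_DEEPSEEK_API_KEY when it is missing.
result = strategy.validate(SimpleNamespace(expert_provider="deepseek"))
assert result.valid
assert os.environ["EXPERT_DEEPSEEK_API_KEY"] == "sk-placeholder"
```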
memory.py

@@ -1,123 +1,147 @@
import os
from typing import Dict, List, Any, Union, Optional, Set
from typing_extensions import TypedDict

from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from langchain_core.tools import tool


class WorkLogEntry(TypedDict):
    timestamp: str
    event: str


class SnippetInfo(TypedDict):
    """Type definition for source code snippet information"""

    filepath: str
    line_number: int
    snippet: str
    description: Optional[str]


console = Console()

# Global memory store
_global_memory: Dict[
    str,
    Union[
        List[Any],
        Dict[int, str],
        Dict[int, SnippetInfo],
        int,
        Set[str],
        bool,
        str,
        int,
        List[WorkLogEntry],
    ],
] = {
    "research_notes": [],
    "plans": [],
    "tasks": {},  # Dict[int, str] - ID to task mapping
    "task_completed": False,  # Flag indicating if task is complete
    "completion_message": "",  # Message explaining completion
    "task_id_counter": 1,  # Counter for generating unique task IDs
    "key_facts": {},  # Dict[int, str] - ID to fact mapping
    "key_fact_id_counter": 1,  # Counter for generating unique fact IDs
    "key_snippets": {},  # Dict[int, SnippetInfo] - ID to snippet mapping
    "key_snippet_id_counter": 1,  # Counter for generating unique snippet IDs
    "implementation_requested": False,
    "related_files": {},  # Dict[int, str] - ID to filepath mapping
    "related_file_id_counter": 1,  # Counter for generating unique file IDs
    "plan_completed": False,
    "agent_depth": 0,
    "work_log": [],  # List[WorkLogEntry] - Timestamped work events
}


@tool("emit_research_notes")
def emit_research_notes(notes: str) -> str:
    """Store research notes in global memory.

    Args:
        notes: REQUIRED The research notes to store

    Returns:
        The stored notes
    """
    _global_memory["research_notes"].append(notes)
    console.print(Panel(Markdown(notes), title="🔍 Research Notes"))
    return notes


@tool("emit_plan")
def emit_plan(plan: str) -> str:
    """Store a plan step in global memory.

    Args:
        plan: The plan step to store (markdown format; be clear, complete, use newlines, and use as many tokens as you need)

    Returns:
        The stored plan
    """
    _global_memory["plans"].append(plan)
    console.print(Panel(Markdown(plan), title="📋 Plan"))
    log_work_event(f"Added plan step:\n\n{plan}")
    return plan


@tool("emit_task")
def emit_task(task: str) -> str:
    """Store a task in global memory.

    Args:
        task: The task to store

    Returns:
        String confirming task storage with ID number
    """
    # Get and increment task ID
    task_id = _global_memory["task_id_counter"]
    _global_memory["task_id_counter"] += 1

    # Store task with ID
    _global_memory["tasks"][task_id] = task

    console.print(Panel(Markdown(task), title=f"✅ Task #{task_id}"))
    log_work_event(f"Task #{task_id} added:\n\n{task}")
    return f"Task #{task_id} stored."


@tool("emit_key_facts")
def emit_key_facts(facts: List[str]) -> str:
    """Store multiple key facts about the project or current task in global memory.

    Args:
        facts: List of key facts to store

    Returns:
        List of stored fact confirmation messages
    """
    results = []
    for fact in facts:
        # Get and increment fact ID
        fact_id = _global_memory["key_fact_id_counter"]
        _global_memory["key_fact_id_counter"] += 1

        # Store fact with ID
        _global_memory["key_facts"][fact_id] = fact

        # Display panel with ID
        console.print(
            Panel(
                Markdown(fact),
                title=f"💡 Key Fact #{fact_id}",
                border_style="bright_cyan",
            )
        )

        # Add result message
        results.append(f"Stored fact #{fact_id}: {fact}")

    log_work_event(f"Stored {len(facts)} key facts.")
    return "Facts stored."
@@ -125,50 +149,54 @@ def emit_key_facts(facts: List[str]) -> str:
def delete_key_facts(fact_ids: List[int]) -> str:
    """Delete multiple key facts from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        fact_ids: List of fact IDs to delete

    Returns:
        List of success messages for deleted facts
    """
    results = []
    for fact_id in fact_ids:
        if fact_id in _global_memory["key_facts"]:
            # Delete the fact
            deleted_fact = _global_memory["key_facts"].pop(fact_id)
            success_msg = f"Successfully deleted fact #{fact_id}: {deleted_fact}"
            console.print(
                Panel(Markdown(success_msg), title="Fact Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted facts {fact_ids}.")
    return "Facts deleted."


@tool("delete_tasks")
def delete_tasks(task_ids: List[int]) -> str:
    """Delete multiple tasks from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        task_ids: List of task IDs to delete

    Returns:
        Confirmation message
    """
    results = []
    for task_id in task_ids:
        if task_id in _global_memory["tasks"]:
            # Delete the task
            deleted_task = _global_memory["tasks"].pop(task_id)
            success_msg = f"Successfully deleted task #{task_id}: {deleted_task}"
            console.print(
                Panel(Markdown(success_msg), title="Task Deleted", border_style="green")
            )
            results.append(success_msg)

    log_work_event(f"Deleted tasks {task_ids}.")
    return "Tasks deleted."


@tool("request_implementation")
def request_implementation() -> str:
    """Request that implementation proceed after research/planning.
@@ -178,46 +206,47 @@ def request_implementation() -> str:
    Do you need to request research subtasks first?
    Have you run relevant unit tests, if they exist, to get a baseline (this can be a subtask)?
    Do you need to crawl deeper to find all related files and symbols?

    Returns:
        Empty string
    """
    _global_memory["implementation_requested"] = True
    console.print(Panel("🚀 Implementation Requested", style="yellow", padding=0))
    log_work_event("Implementation requested.")
    return ""


@tool("emit_key_snippets")
def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
    """Store multiple key source code snippets in global memory.
    Automatically adds the filepaths of the snippets to related files.

    This is for **existing**, or **just-written** files, not for things to be created in the future.

    Args:
        snippets: REQUIRED List of snippet information dictionaries containing:
            - filepath: Path to the source file
            - line_number: Line number where the snippet starts
            - snippet: The source code snippet text
            - description: Optional description of the significance

    Returns:
        List of stored snippet confirmation messages
    """
    # First collect unique filepaths to add as related files
    emit_related_files.invoke(
        {"files": [snippet_info["filepath"] for snippet_info in snippets]}
    )

    results = []
    for snippet_info in snippets:
        # Get and increment snippet ID
        snippet_id = _global_memory["key_snippet_id_counter"]
        _global_memory["key_snippet_id_counter"] += 1

        # Store snippet info
        _global_memory["key_snippets"][snippet_id] = snippet_info

        # Format display text as markdown
        display_text = [
            f"**Source Location**:",
@@ -226,153 +255,170 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str:
            "",  # Empty line before code block
            "**Code**:",
            "```python",
            snippet_info["snippet"].rstrip(),  # Remove trailing whitespace
            "```",
        ]
        if snippet_info["description"]:
            display_text.extend(["", "**Description**:", snippet_info["description"]])

        # Display panel
        console.print(
            Panel(
                Markdown("\n".join(display_text)),
                title=f"📝 Key Snippet #{snippet_id}",
                border_style="bright_cyan",
            )
        )

        results.append(f"Stored snippet #{snippet_id}")

    log_work_event(f"Stored {len(snippets)} code snippets.")
    return "Snippets stored."


@tool("delete_key_snippets")
def delete_key_snippets(snippet_ids: List[int]) -> str:
    """Delete multiple key snippets from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        snippet_ids: List of snippet IDs to delete

    Returns:
        List of success messages for deleted snippets
    """
    results = []
    for snippet_id in snippet_ids:
        if snippet_id in _global_memory["key_snippets"]:
            # Delete the snippet
            deleted_snippet = _global_memory["key_snippets"].pop(snippet_id)
            success_msg = f"Successfully deleted snippet #{snippet_id} from {deleted_snippet['filepath']}"
            console.print(
                Panel(
                    Markdown(success_msg), title="Snippet Deleted", border_style="green"
                )
            )
            results.append(success_msg)

    log_work_event(f"Deleted snippets {snippet_ids}.")
    return "Snippets deleted."


@tool("swap_task_order")
def swap_task_order(id1: int, id2: int) -> str:
    """Swap the order of two tasks in global memory by their IDs.

    Args:
        id1: First task ID
        id2: Second task ID

    Returns:
        Success or error message depending on outcome
    """
    # Validate IDs are different
    if id1 == id2:
        return "Cannot swap task with itself"

    # Validate both IDs exist
    if id1 not in _global_memory["tasks"] or id2 not in _global_memory["tasks"]:
        return "Invalid task ID(s)"

    # Swap the tasks
    _global_memory["tasks"][id1], _global_memory["tasks"][id2] = (
        _global_memory["tasks"][id2],
        _global_memory["tasks"][id1],
    )

    # Display what was swapped
    console.print(
        Panel(
            Markdown(f"Swapped:\n- Task #{id1} ↔️ Task #{id2}"),
            title="🔄 Tasks Reordered",
            border_style="green",
        )
    )

    return "Tasks swapped."


@tool("one_shot_completed")
def one_shot_completed(message: str) -> str:
    """Signal that a one-shot task has been completed and execution should stop.

    Only call this if you have already **fully** completed the original request.

    Args:
        message: Completion message to display
    """
    if _global_memory.get("implementation_requested", False):
        return "Cannot complete in one shot - implementation was requested"

    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    log_work_event(f"Task completed\n\n{message}")
    return "Completion noted."


@tool("task_completed")
def task_completed(message: str) -> str:
    """Mark the current task as completed with a completion message.

    Args:
        message: Message explaining how/why the task is complete

    Returns:
        The completion message
    """
    _global_memory["task_completed"] = True
    _global_memory["completion_message"] = message
    console.print(Panel(Markdown(message), title="✅ Task Completed"))
    return "Completion noted."


@tool("plan_implementation_completed")
def plan_implementation_completed(message: str) -> str:
    """Mark the entire implementation plan as completed.

    Args:
        message: Message explaining how the implementation plan was completed

    Returns:
        Confirmation message
    """
    _global_memory["plan_completed"] = True
    _global_memory["completion_message"] = message
    _global_memory["tasks"].clear()  # Clear task list when plan is completed
    _global_memory["task_id_counter"] = 1
    console.print(Panel(Markdown(message), title="✅ Plan Executed"))
    log_work_event(f"Plan execution completed:\n\n{message}")
    return "Plan completion noted and task list cleared."


def get_related_files() -> List[str]:
    """Get the current list of related files.

    Returns:
        List of formatted strings in the format 'ID#X path/to/file.py'
    """
    files = _global_memory["related_files"]
    return [f"ID#{file_id} {filepath}" for file_id, filepath in sorted(files.items())]


@tool("emit_related_files")
def emit_related_files(files: List[str]) -> str:
    """Store multiple related files that tools should work with.

    Args:
        files: List of file paths to add

    Returns:
        Formatted string containing file IDs and paths for all processed files
    """
    results = []
    added_files = []
    invalid_paths = []

    # Process files
    for file in files:
        # First check if path exists
@@ -380,111 +426,115 @@ def emit_related_files(files: List[str]) -> str:
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' does not exist")
            continue

        # Then check if it's a directory
        if os.path.isdir(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' is a directory, not a file")
            continue

        # Finally validate it's a regular file
        if not os.path.isfile(file):
            invalid_paths.append(file)
            results.append(f"Error: Path '{file}' exists but is not a regular file")
            continue

        # Check if file path already exists in values
        existing_id = None
        for fid, fpath in _global_memory["related_files"].items():
            if fpath == file:
                existing_id = fid
                break

        if existing_id is not None:
            # File exists, use existing ID
            results.append(f"File ID #{existing_id}: {file}")
        else:
            # New file, assign new ID
            file_id = _global_memory["related_file_id_counter"]
            _global_memory["related_file_id_counter"] += 1

            # Store file with ID
            _global_memory["related_files"][file_id] = file
            added_files.append((file_id, file))
            results.append(f"File ID #{file_id}: {file}")

    # Rich output - single consolidated panel
    if added_files:
        files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
        md_content = f"**Files Noted:**\n{files_added_md}"
        console.print(
            Panel(
                Markdown(md_content),
                title="📁 Related Files Noted",
                border_style="green",
            )
        )

    return "\n".join(results)


def log_work_event(event: str) -> str:
    """Add timestamped entry to work log.

    Internal function used to track major events during agent execution.
    Each entry is stored with an ISO format timestamp.

    Args:
        event: Description of the event to log

    Returns:
        Confirmation message

    Note:
        Entries can be retrieved with get_work_log() as markdown formatted text.
    """
    from datetime import datetime

    entry = WorkLogEntry(timestamp=datetime.now().isoformat(), event=event)
    _global_memory["work_log"].append(entry)
    return f"Event logged: {event}"


def get_work_log() -> str:
    """Return formatted markdown of work log entries.

    Returns:
        Markdown formatted text with timestamps as headings and events as content,
        or 'No work log entries' if the log is empty.

    Example:
        ## 2024-12-23T11:39:10

        Task #1 added: Create login form
    """
    if not _global_memory["work_log"]:
        return "No work log entries"

    entries = []
    for entry in _global_memory["work_log"]:
        entries.extend(
            [
                f"## {entry['timestamp']}",
                "",
                entry["event"],
                "",  # Blank line between entries
            ]
        )

    return "\n".join(entries).rstrip()  # Remove trailing newline


def reset_work_log() -> str:
    """Clear the work log.

    Returns:
        Confirmation message

    Note:
        This permanently removes all work log entries. The operation cannot be undone.
    """
    _global_memory["work_log"].clear()
    return "Work log cleared"
@@ -492,37 +542,44 @@ def reset_work_log() -> str:
def deregister_related_files(file_ids: List[int]) -> str:
    """Delete multiple related files from global memory by their IDs.
    Silently skips any IDs that don't exist.

    Args:
        file_ids: List of file IDs to delete

    Returns:
        Success message string
    """
    results = []
    for file_id in file_ids:
        if file_id in _global_memory["related_files"]:
            # Delete the file reference
            deleted_file = _global_memory["related_files"].pop(file_id)
            success_msg = (
                f"Successfully removed related file #{file_id}: {deleted_file}"
            )
            console.print(
                Panel(
                    Markdown(success_msg),
                    title="File Reference Removed",
                    border_style="green",
                )
            )
            results.append(success_msg)

    return "File references removed."


def get_memory_value(key: str) -> str:
    """Get a value from global memory.

    Different memory types return different formats:
    - key_facts: Returns numbered list of facts in format '#ID: fact'
    - key_snippets: Returns formatted snippets with file path, line number and content
    - All other types: Returns newline-separated list of values

    Args:
        key: The key to get from memory

    Returns:
        String representation of the memory values:
        - For key_facts: '#ID: fact' format, one per line
@@ -530,23 +587,25 @@ def get_memory_value(key: str) -> str:
        - For other types: One value per line
    """
    values = _global_memory.get(key, [])

    if key == "key_facts":
        # For empty dict, return empty string
        if not values:
            return ""
        # Sort by ID for consistent output and format as markdown sections
        facts = []
        for k, v in sorted(values.items()):
            facts.extend(
                [
                    f"## 🔑 Key Fact #{k}",
                    "",  # Empty line for better markdown spacing
                    v,
                    "",  # Empty line between facts
                ]
            )
        return "\n".join(facts).rstrip()  # Remove trailing newline

    if key == "key_snippets":
        if not values:
            return ""
        # Format each snippet with file info and content using markdown
@@ -555,26 +614,25 @@ def get_memory_value(key: str) -> str:
            snippet_text = [
                f"## 📝 Code Snippet #{k}",
                "",  # Empty line for better markdown spacing
                "**Source Location**:",
                f"- File: `{v['filepath']}`",
                f"- Line: `{v['line_number']}`",
                "",  # Empty line before code block
                "**Code**:",
                "```python",
                v["snippet"].rstrip(),  # Remove trailing whitespace
                "```",
            ]
            if v["description"]:
                # Add empty line and description
                snippet_text.extend(["", "**Description**:", v["description"]])
            snippets.append("\n".join(snippet_text))
        return "\n\n".join(snippets)

    if key == "work_log":
        if not values:
            return ""
        entries = [f"## {entry['timestamp']}\n{entry['event']}" for entry in values]
        return "\n\n".join(entries)

    # For other types (lists), join with newlines
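A minimal sketch of how the refactored memory helpers are exercised (illustrative only; the exact module path for memory.py is an assumption, and tool-decorated functions are invoked through LangChain's .invoke):

```python
from ra_aid.tools.memory import emit_key_facts, get_memory_value, log_work_event  # path assumed

emit_key_facts.invoke({"facts": ["RA.Aid now supports the DeepSeek provider."]})
log_work_event("Recorded a key fact about DeepSeek support.")

print(get_memory_value("key_facts"))  # "## 🔑 Key Fact #1 ..." markdown section
print(get_memory_value("work_log"))   # timestamped markdown entries
```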
Unit tests for ra_aid/llm.py

@@ -9,7 +9,13 @@ from dataclasses import dataclass
from ra_aid.agents.ciayn_agent import CiaynAgent

from ra_aid.env import validate_environment
from ra_aid.llm import (
    initialize_llm,
    initialize_expert_llm,
    get_env_var,
    get_provider_config,
    create_llm_client
)

@pytest.fixture
def clean_env(monkeypatch):
@@ -39,7 +45,8 @@ def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="o1",
        temperature=0
    )

def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):

@@ -49,7 +56,8 @@ def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        model="gpt-4-preview",
        temperature=0
    )

def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):

@@ -59,7 +67,8 @@ def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch):
    mock_gemini.assert_called_once_with(
        api_key="test-key",
        model="gemini-2.0-flash-thinking-exp-1219",
        temperature=0
    )

def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):

@@ -69,7 +78,8 @@ def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch):
    mock_anthropic.assert_called_once_with(
        api_key="test-key",
        model_name="claude-3",
        temperature=0
    )

def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):

@@ -80,7 +90,8 @@ def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="models/mistral-large",
        temperature=0
    )

def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch):

@@ -92,7 +103,8 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch):
    mock_openai.assert_called_once_with(
        api_key="test-key",
        base_url="http://test-url",
        model="local-model",
        temperature=0
    )

def test_initialize_expert_unsupported_provider(clean_env):
@@ -311,45 +323,49 @@ def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, monkeypatch):
        model="gemini-2.0-flash-thinking-exp-1219"
    )

@dataclass
class Args:
    """Test arguments class."""
    provider: str
    expert_provider: str
    model: str = None
    expert_model: str = None

def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch):
    """Test environment variable precedence and fallback."""
    # Test get_env_var helper with fallback
    monkeypatch.setenv("TEST_KEY", "base-value")
    monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value")

    # Set up base environment first
    monkeypatch.setenv("OPENAI_API_KEY", "base-key")
    assert get_env_var("TEST_KEY") == "base-value"
    assert get_env_var("TEST_KEY", expert=True) == "expert-value"

    # Test fallback when expert value not set
    monkeypatch.delenv("EXPERT_TEST_KEY", raising=False)
    assert get_env_var("TEST_KEY", expert=True) == "base-value"

    # Test provider config
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key")
    monkeypatch.setenv("TAVILY_API_KEY", "tavily-key")
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    args = Args(provider="openai", expert_provider="openai")
    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
    assert expert_enabled
    assert not expert_missing
    assert web_enabled
    assert not web_missing

    config = get_provider_config("openai", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test LLM client creation with expert mode
    llm = create_llm_client("openai", "o1", is_expert=True)
    mock_openai.assert_called_with(
        api_key="expert-key",
        model="o1",
        temperature=0
    )

    # Test environment validation
    monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "")
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("TAVILY_API_KEY", raising=False)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key")
    monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
    monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307")

    args = Args(provider="anthropic", expert_provider="openai")
    expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args)
    assert not expert_enabled
    assert expert_missing
@@ -372,3 +388,122 @@ def mock_gemini():
    with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock:
        mock.return_value = Mock(spec=ChatGoogleGenerativeAI)
        yield mock

@pytest.fixture
def mock_deepseek_reasoner():
    """Mock ChatDeepseekReasoner for testing DeepSeek provider initialization."""
    with patch('ra_aid.llm.ChatDeepseekReasoner') as mock:
        mock.return_value = Mock()
        yield mock

def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek provider initialization with different models."""
    monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=1,
        model="deepseek-chat"
    )

def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert DeepSeek provider initialization."""
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key")

    # Test with reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-reasoner")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )

    # Test with non-reasoner model
    model = initialize_expert_llm("deepseek", "deepseek-chat")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-chat"
    )

def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model
    model = initialize_llm("openrouter", "deepseek/deepseek-r1")
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=1,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = initialize_llm("openrouter", "mistral/mistral-large")
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large"
    )

def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch):
    """Test expert OpenRouter DeepSeek model initialization."""
    monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key")

    # Test with DeepSeek R1 model via create_llm_client
    model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        temperature=0,
        model="deepseek/deepseek-r1"
    )

    # Test with non-DeepSeek model
    model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True)
    mock_openai.assert_called_with(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        model="mistral/mistral-large",
        temperature=0
    )

def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch):
    """Test DeepSeek environment variable fallback behavior."""
    # Test environment variable helper with fallback
    monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key")
    assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key"

    # Test provider config with fallback
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "base-key"
    assert config["base_url"] == "https://api.deepseek.com"

    # Test with expert key
    monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key")
    config = get_provider_config("deepseek", is_expert=True)
    assert config["api_key"] == "expert-key"

    # Test client creation with expert key
    model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True)
    mock_deepseek_reasoner.assert_called_with(
        api_key="expert-key",
        base_url="https://api.deepseek.com",
        temperature=0,
        model="deepseek-reasoner"
    )
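One behavior the tests above do not exercise is the message-alternation override in ChatDeepseekReasoner._generate. A hedged sketch of what such a test could look like (the test name, the patch target, and the construction arguments are illustrative assumptions, not part of the commit):

```python
from unittest.mock import patch
from langchain_core.messages import HumanMessage
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner

def test_deepseek_reasoner_merges_consecutive_human_messages():
    """Consecutive human messages should be merged before the API is called."""
    llm = ChatDeepseekReasoner(
        api_key="test-key",
        base_url="https://api.deepseek.com",
        model="deepseek-reasoner",
    )
    messages = [HumanMessage(content="first"), HumanMessage(content="second")]

    # Patch the parent implementation so no network call is made.
    with patch("langchain_openai.ChatOpenAI._generate") as parent_generate:
        llm._generate(messages)

    processed = parent_generate.call_args[0][0]  # first positional argument
    assert len(processed) == 1
    assert processed[0].content == "first\n\nsecond"
```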