Merge pull request #78 from ariel-frischer/fix-token-limit-bug

Fix token limit bug with custom --research/--planner args
This commit is contained in:
Ariel Frischer 2025-02-01 13:06:37 -08:00 committed by GitHub
commit c14fad6d14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 438 additions and 732 deletions

View File

@ -34,9 +34,11 @@ logger = get_logger(__name__)
def launch_webui(host: str, port: int): def launch_webui(host: str, port: int):
"""Launch the RA.Aid web interface.""" """Launch the RA.Aid web interface."""
from ra_aid.webui import run_server from ra_aid.webui import run_server
print(f"Starting RA.Aid web interface on http://{host}:{port}") print(f"Starting RA.Aid web interface on http://{host}:{port}")
run_server(host=host, port=port) run_server(host=host, port=port)
def parse_arguments(args=None): def parse_arguments(args=None):
VALID_PROVIDERS = [ VALID_PROVIDERS = [
"anthropic", "anthropic",

View File

@ -5,7 +5,7 @@ import sys
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Dict, List, Optional, Sequence from typing import Any, Dict, List, Literal, Optional, Sequence
import litellm import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
@ -122,13 +122,22 @@ def state_modifier(
return [first_message] + trimmed_remaining return [first_message] + trimmed_remaining
def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]: def get_model_token_limit(
"""Get the token limit for the current model configuration. config: Dict[str, Any], agent_type: Literal["default", "research", "planner"]
) -> Optional[int]:
"""Get the token limit for the current model configuration based on agent type.
Returns: Returns:
Optional[int]: The token limit if found, None otherwise Optional[int]: The token limit if found, None otherwise
""" """
try: try:
if agent_type == "research":
provider = config.get("research_provider", "") or config.get("provider", "")
model_name = config.get("research_model", "") or config.get("model", "")
elif agent_type == "planner":
provider = config.get("planner_provider", "") or config.get("provider", "")
model_name = config.get("planner_model", "") or config.get("model", "")
else:
provider = config.get("provider", "") provider = config.get("provider", "")
model_name = config.get("model", "") model_name = config.get("model", "")
@ -224,6 +233,7 @@ def create_agent(
tools: List[Any], tools: List[Any],
*, *,
checkpointer: Any = None, checkpointer: Any = None,
agent_type: str = "default",
) -> Any: ) -> Any:
"""Create a react agent with the given configuration. """Create a react agent with the given configuration.
@ -245,7 +255,9 @@ def create_agent(
""" """
try: try:
config = _global_memory.get("config", {}) config = _global_memory.get("config", {})
max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT max_input_tokens = (
get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT
)
# Use REACT agent for Anthropic Claude models, otherwise use CIAYN # Use REACT agent for Anthropic Claude models, otherwise use CIAYN
if is_anthropic_claude(config): if is_anthropic_claude(config):
@ -260,7 +272,7 @@ def create_agent(
# Default to REACT agent if provider/model detection fails # Default to REACT agent if provider/model detection fails
logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.") logger.warning(f"Failed to detect model type: {e}. Defaulting to REACT agent.")
config = _global_memory.get("config", {}) config = _global_memory.get("config", {})
max_input_tokens = get_model_token_limit(config) max_input_tokens = get_model_token_limit(config, agent_type)
agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens)
return create_react_agent(model, tools, **agent_kwargs) return create_react_agent(model, tools, **agent_kwargs)
@ -326,7 +338,7 @@ def run_research_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@ -349,9 +361,11 @@ def run_research_agent(
prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format(
base_task=base_task_or_query, base_task=base_task_or_query,
research_only_note="" research_only_note=(
""
if research_only if research_only
else " Only request implementation if the user explicitly asked for changes to be made.", else " Only request implementation if the user explicitly asked for changes to be made."
),
expert_section=expert_section, expert_section=expert_section,
human_section=human_section, human_section=human_section,
web_research_section=web_research_section, web_research_section=web_research_section,
@ -455,7 +469,7 @@ def run_web_research_agent(
tools = get_web_research_tools(expert_enabled=expert_enabled) tools = get_web_research_tools(expert_enabled=expert_enabled)
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="research")
expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else ""
@ -536,7 +550,7 @@ def run_planning_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else ""
human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else ""
@ -556,9 +570,11 @@ def run_planning_agent(
key_facts=get_memory_value("key_facts"), key_facts=get_memory_value("key_facts"),
key_snippets=get_memory_value("key_snippets"), key_snippets=get_memory_value("key_snippets"),
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
research_only_note="" research_only_note=(
""
if config.get("research_only") if config.get("research_only")
else " Only request implementation if the user explicitly asked for changes to be made.", else " Only request implementation if the user explicitly asked for changes to be made."
),
) )
config = _global_memory.get("config", {}) if not config else config config = _global_memory.get("config", {}) if not config else config
@ -634,7 +650,7 @@ def run_task_implementation_agent(
web_research_enabled=config.get("web_research_enabled", False), web_research_enabled=config.get("web_research_enabled", False),
) )
agent = create_agent(model, tools, checkpointer=memory) agent = create_agent(model, tools, checkpointer=memory, agent_type="planner")
prompt = IMPLEMENTATION_PROMPT.format( prompt = IMPLEMENTATION_PROMPT.format(
base_task=base_task, base_task=base_task,
@ -647,12 +663,16 @@ def run_task_implementation_agent(
research_notes=get_memory_value("research_notes"), research_notes=get_memory_value("research_notes"),
work_log=get_memory_value("work_log"), work_log=get_memory_value("work_log"),
expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "",
human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION human_section=(
HUMAN_PROMPT_SECTION_IMPLEMENTATION
if _global_memory.get("config", {}).get("hil", False) if _global_memory.get("config", {}).get("hil", False)
else "", else ""
web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT ),
web_research_section=(
WEB_RESEARCH_PROMPT_SECTION_CHAT
if config.get("web_research_enabled") if config.get("web_research_enabled")
else "", else ""
),
) )
config = _global_memory.get("config", {}) if not config else config config = _global_memory.get("config", {}) if not config else config

View File

@ -1,9 +1,6 @@
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
known_temp_providers = {"openai", "anthropic", "openrouter", "openai-compatible", "gemini", "deepseek"}
from .models_params import models_params
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
@ -12,6 +9,17 @@ from langchain_openai import ChatOpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.logging_config import get_logger from ra_aid.logging_config import get_logger
from .models_params import models_params
known_temp_providers = {
"openai",
"anthropic",
"openrouter",
"openai-compatible",
"gemini",
"deepseek",
}
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@ -6,710 +6,324 @@ DEFAULT_TOKEN_LIMIT = 100000
models_params = { models_params = {
"openai": { "openai": {
"gpt-3.5-turbo-0125": { "gpt-3.5-turbo-0125": {"token_limit": 16385, "supports_temperature": True},
"token_limit": 16385, "gpt-3.5": {"token_limit": 4096, "supports_temperature": True},
"supports_temperature": True "gpt-3.5-turbo": {"token_limit": 16385, "supports_temperature": True},
}, "gpt-3.5-turbo-1106": {"token_limit": 16385, "supports_temperature": True},
"gpt-3.5": { "gpt-3.5-turbo-instruct": {"token_limit": 4096, "supports_temperature": True},
"token_limit": 4096, "gpt-4-0125-preview": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gpt-4-turbo-preview": {"token_limit": 128000, "supports_temperature": True},
}, "gpt-4-turbo": {"token_limit": 128000, "supports_temperature": True},
"gpt-3.5-turbo": { "gpt-4-turbo-2024-04-09": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 16385, "gpt-4-1106-preview": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gpt-4-vision-preview": {"token_limit": 128000, "supports_temperature": True},
}, "gpt-4": {"token_limit": 8192, "supports_temperature": True},
"gpt-3.5-turbo-1106": { "gpt-4-0613": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 16385, "gpt-4-32k": {"token_limit": 32768, "supports_temperature": True},
"supports_temperature": True "gpt-4-32k-0613": {"token_limit": 32768, "supports_temperature": True},
}, "gpt-4o": {"token_limit": 128000, "supports_temperature": True},
"gpt-3.5-turbo-instruct": { "gpt-4o-2024-08-06": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 4096, "gpt-4o-2024-05-13": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True},
}, "o1-preview": {"token_limit": 128000, "supports_temperature": False},
"gpt-4-0125-preview": { "o1-mini": {"token_limit": 128000, "supports_temperature": False},
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo-2024-04-09": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-1106-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-vision-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4": {
"token_limit": 8192,
"supports_temperature": True
},
"gpt-4-0613": {
"token_limit": 8192,
"supports_temperature": True
},
"gpt-4-32k": {
"token_limit": 32768,
"supports_temperature": True
},
"gpt-4-32k-0613": {
"token_limit": 32768,
"supports_temperature": True
},
"gpt-4o": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4o-2024-08-06": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4o-2024-05-13": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4o-mini": {
"token_limit": 128000,
"supports_temperature": True
},
"o1-preview": {
"token_limit": 128000,
"supports_temperature": False
},
"o1-mini": {
"token_limit": 128000,
"supports_temperature": False
}
}, },
"azure_openai": { "azure_openai": {
"gpt-3.5-turbo-0125": { "gpt-3.5-turbo-0125": {"token_limit": 16385, "supports_temperature": True},
"token_limit": 16385, "gpt-3.5": {"token_limit": 4096, "supports_temperature": True},
"supports_temperature": True "gpt-3.5-turbo": {"token_limit": 16385, "supports_temperature": True},
}, "gpt-3.5-turbo-1106": {"token_limit": 16385, "supports_temperature": True},
"gpt-3.5": { "gpt-3.5-turbo-instruct": {"token_limit": 4096, "supports_temperature": True},
"token_limit": 4096, "gpt-4-0125-preview": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gpt-4-turbo-preview": {"token_limit": 128000, "supports_temperature": True},
}, "gpt-4-turbo": {"token_limit": 128000, "supports_temperature": True},
"gpt-3.5-turbo": { "gpt-4-turbo-2024-04-09": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 16385, "gpt-4-1106-preview": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gpt-4-vision-preview": {"token_limit": 128000, "supports_temperature": True},
}, "gpt-4": {"token_limit": 8192, "supports_temperature": True},
"gpt-3.5-turbo-1106": { "gpt-4-0613": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 16385, "gpt-4-32k": {"token_limit": 32768, "supports_temperature": True},
"supports_temperature": True "gpt-4-32k-0613": {"token_limit": 32768, "supports_temperature": True},
}, "gpt-4o": {"token_limit": 128000, "supports_temperature": True},
"gpt-3.5-turbo-instruct": { "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 4096, "chatgpt-4o-latest": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "o1-preview": {"token_limit": 128000, "supports_temperature": False},
}, "o1-mini": {"token_limit": 128000, "supports_temperature": False},
"gpt-4-0125-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-turbo-2024-04-09": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-1106-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4-vision-preview": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4": {
"token_limit": 8192,
"supports_temperature": True
},
"gpt-4-0613": {
"token_limit": 8192,
"supports_temperature": True
},
"gpt-4-32k": {
"token_limit": 32768,
"supports_temperature": True
},
"gpt-4-32k-0613": {
"token_limit": 32768,
"supports_temperature": True
},
"gpt-4o": {
"token_limit": 128000,
"supports_temperature": True
},
"gpt-4o-mini": {
"token_limit": 128000,
"supports_temperature": True
},
"chatgpt-4o-latest": {
"token_limit": 128000,
"supports_temperature": True
},
"o1-preview": {
"token_limit": 128000,
"supports_temperature": False
},
"o1-mini": {
"token_limit": 128000,
"supports_temperature": False
}
}, },
"google_genai": { "google_genai": {
"gemini-pro": { "gemini-pro": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 128000,
"supports_temperature": True
},
"gemini-1.5-flash-latest": { "gemini-1.5-flash-latest": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"gemini-1.5-pro-latest": { "gemini-1.5-pro-latest": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 128000, "models/embedding-001": {"token_limit": 2048, "supports_temperature": True},
"supports_temperature": True
},
"models/embedding-001": {
"token_limit": 2048,
"supports_temperature": True
}
}, },
"google_vertexai": { "google_vertexai": {
"gemini-1.5-flash": { "gemini-1.5-flash": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 128000, "gemini-1.5-pro": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gemini-1.0-pro": {"token_limit": 128000, "supports_temperature": True},
},
"gemini-1.5-pro": {
"token_limit": 128000,
"supports_temperature": True
},
"gemini-1.0-pro": {
"token_limit": 128000,
"supports_temperature": True
}
}, },
"ollama": { "ollama": {
"command-r": { "command-r": {"token_limit": 12800, "supports_temperature": True},
"token_limit": 12800, "codellama": {"token_limit": 16000, "supports_temperature": True},
"supports_temperature": True "dbrx": {"token_limit": 32768, "supports_temperature": True},
}, "deepseek-coder:33b": {"token_limit": 16000, "supports_temperature": True},
"codellama": { "falcon": {"token_limit": 2048, "supports_temperature": True},
"token_limit": 16000, "llama2": {"token_limit": 4096, "supports_temperature": True},
"supports_temperature": True "llama2:7b": {"token_limit": 4096, "supports_temperature": True},
}, "llama2:13b": {"token_limit": 4096, "supports_temperature": True},
"dbrx": { "llama2:70b": {"token_limit": 4096, "supports_temperature": True},
"token_limit": 32768, "llama3": {"token_limit": 8192, "supports_temperature": True},
"supports_temperature": True "llama3:8b": {"token_limit": 8192, "supports_temperature": True},
}, "llama3:70b": {"token_limit": 8192, "supports_temperature": True},
"deepseek-coder:33b": { "llama3.1": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 16000, "llama3.1:8b": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "llama3.1:70b": {"token_limit": 128000, "supports_temperature": True},
}, "lama3.1:405b": {"token_limit": 128000, "supports_temperature": True},
"falcon": { "llama3.2": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 2048, "llama3.2:1b": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "llama3.2:3b": {"token_limit": 128000, "supports_temperature": True},
}, "llama3.3:70b": {"token_limit": 128000, "supports_temperature": True},
"llama2": { "scrapegraph": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 4096, "mistral-small": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "mistral-openorca": {"token_limit": 32000, "supports_temperature": True},
}, "mistral-large": {"token_limit": 128000, "supports_temperature": True},
"llama2:7b": { "grok-1": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 4096, "llava": {"token_limit": 4096, "supports_temperature": True},
"supports_temperature": True "mixtral:8x22b-instruct": {"token_limit": 65536, "supports_temperature": True},
}, "nomic-embed-text": {"token_limit": 8192, "supports_temperature": True},
"llama2:13b": { "nous-hermes2:34b": {"token_limit": 4096, "supports_temperature": True},
"token_limit": 4096, "orca-mini": {"token_limit": 2048, "supports_temperature": True},
"supports_temperature": True "phi3:3.8b": {"token_limit": 12800, "supports_temperature": True},
}, "phi3:14b": {"token_limit": 128000, "supports_temperature": True},
"llama2:70b": { "qwen:0.5b": {"token_limit": 32000, "supports_temperature": True},
"token_limit": 4096, "qwen:1.8b": {"token_limit": 32000, "supports_temperature": True},
"supports_temperature": True "qwen:4b": {"token_limit": 32000, "supports_temperature": True},
}, "qwen:14b": {"token_limit": 32000, "supports_temperature": True},
"llama3": { "qwen:32b": {"token_limit": 32000, "supports_temperature": True},
"token_limit": 8192, "qwen:72b": {"token_limit": 32000, "supports_temperature": True},
"supports_temperature": True "qwen:110b": {"token_limit": 32000, "supports_temperature": True},
}, "stablelm-zephyr": {"token_limit": 8192, "supports_temperature": True},
"llama3:8b": { "wizardlm2:8x22b": {"token_limit": 65536, "supports_temperature": True},
"token_limit": 8192, "mistral": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "gemma2": {"token_limit": 128000, "supports_temperature": True},
}, "gemma2:9b": {"token_limit": 128000, "supports_temperature": True},
"llama3:70b": { "gemma2:27b": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 8192,
"supports_temperature": True
},
"llama3.1": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.1:8b": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.1:70b": {
"token_limit": 128000,
"supports_temperature": True
},
"lama3.1:405b": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.2": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.2:1b": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.2:3b": {
"token_limit": 128000,
"supports_temperature": True
},
"llama3.3:70b": {
"token_limit": 128000,
"supports_temperature": True
},
"scrapegraph": {
"token_limit": 8192,
"supports_temperature": True
},
"mistral-small": {
"token_limit": 128000,
"supports_temperature": True
},
"mistral-openorca": {
"token_limit": 32000,
"supports_temperature": True
},
"mistral-large": {
"token_limit": 128000,
"supports_temperature": True
},
"grok-1": {
"token_limit": 8192,
"supports_temperature": True
},
"llava": {
"token_limit": 4096,
"supports_temperature": True
},
"mixtral:8x22b-instruct": {
"token_limit": 65536,
"supports_temperature": True
},
"nomic-embed-text": {
"token_limit": 8192,
"supports_temperature": True
},
"nous-hermes2:34b": {
"token_limit": 4096,
"supports_temperature": True
},
"orca-mini": {
"token_limit": 2048,
"supports_temperature": True
},
"phi3:3.8b": {
"token_limit": 12800,
"supports_temperature": True
},
"phi3:14b": {
"token_limit": 128000,
"supports_temperature": True
},
"qwen:0.5b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:1.8b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:4b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:14b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:32b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:72b": {
"token_limit": 32000,
"supports_temperature": True
},
"qwen:110b": {
"token_limit": 32000,
"supports_temperature": True
},
"stablelm-zephyr": {
"token_limit": 8192,
"supports_temperature": True
},
"wizardlm2:8x22b": {
"token_limit": 65536,
"supports_temperature": True
},
"mistral": {
"token_limit": 128000,
"supports_temperature": True
},
"gemma2": {
"token_limit": 128000,
"supports_temperature": True
},
"gemma2:9b": {
"token_limit": 128000,
"supports_temperature": True
},
"gemma2:27b": {
"token_limit": 128000,
"supports_temperature": True
},
# embedding models # embedding models
"shaw/dmeta-embedding-zh-small-q4": { "shaw/dmeta-embedding-zh-small-q4": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"shaw/dmeta-embedding-zh-q4": { "shaw/dmeta-embedding-zh-q4": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"chevalblanc/acge_text_embedding": { "chevalblanc/acge_text_embedding": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"martcreation/dmeta-embedding-zh": { "martcreation/dmeta-embedding-zh": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"snowflake-arctic-embed": { "snowflake-arctic-embed": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 8192, "mxbai-embed-large": {"token_limit": 512, "supports_temperature": True},
"supports_temperature": True
},
"mxbai-embed-large": {
"token_limit": 512,
"supports_temperature": True
}
},
"oneapi": {
"qwen-turbo": {
"token_limit": 6000,
"supports_temperature": True
}
}, },
"oneapi": {"qwen-turbo": {"token_limit": 6000, "supports_temperature": True}},
"nvidia": { "nvidia": {
"meta/llama3-70b-instruct": { "meta/llama3-70b-instruct": {"token_limit": 419, "supports_temperature": True},
"token_limit": 419, "meta/llama3-8b-instruct": {"token_limit": 419, "supports_temperature": True},
"supports_temperature": True "nemotron-4-340b-instruct": {"token_limit": 1024, "supports_temperature": True},
}, "databricks/dbrx-instruct": {"token_limit": 4096, "supports_temperature": True},
"meta/llama3-8b-instruct": { "google/codegemma-7b": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 419, "google/gemma-2b": {"token_limit": 2048, "supports_temperature": True},
"supports_temperature": True "google/gemma-7b": {"token_limit": 8192, "supports_temperature": True},
}, "google/recurrentgemma-2b": {"token_limit": 2048, "supports_temperature": True},
"nemotron-4-340b-instruct": { "meta/codellama-70b": {"token_limit": 16384, "supports_temperature": True},
"token_limit": 1024, "meta/llama2-70b": {"token_limit": 4096, "supports_temperature": True},
"supports_temperature": True
},
"databricks/dbrx-instruct": {
"token_limit": 4096,
"supports_temperature": True
},
"google/codegemma-7b": {
"token_limit": 8192,
"supports_temperature": True
},
"google/gemma-2b": {
"token_limit": 2048,
"supports_temperature": True
},
"google/gemma-7b": {
"token_limit": 8192,
"supports_temperature": True
},
"google/recurrentgemma-2b": {
"token_limit": 2048,
"supports_temperature": True
},
"meta/codellama-70b": {
"token_limit": 16384,
"supports_temperature": True
},
"meta/llama2-70b": {
"token_limit": 4096,
"supports_temperature": True
},
"microsoft/phi-3-mini-128k-instruct": { "microsoft/phi-3-mini-128k-instruct": {
"token_limit": 122880, "token_limit": 122880,
"supports_temperature": True "supports_temperature": True,
}, },
"mistralai/mistral-7b-instruct-v0.2": { "mistralai/mistral-7b-instruct-v0.2": {
"token_limit": 4096, "token_limit": 4096,
"supports_temperature": True "supports_temperature": True,
},
"mistralai/mistral-large": {
"token_limit": 8192,
"supports_temperature": True
}, },
"mistralai/mistral-large": {"token_limit": 8192, "supports_temperature": True},
"mistralai/mixtral-8x22b-instruct-v0.1": { "mistralai/mixtral-8x22b-instruct-v0.1": {
"token_limit": 32768, "token_limit": 32768,
"supports_temperature": True "supports_temperature": True,
}, },
"mistralai/mixtral-8x7b-instruct-v0.1": { "mistralai/mixtral-8x7b-instruct-v0.1": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"snowflake/arctic": { "snowflake/arctic": {"token_limit": 16384, "supports_temperature": True},
"token_limit": 16384,
"supports_temperature": True
}
}, },
"groq": { "groq": {
"llama3-8b-8192": { "llama3-8b-8192": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 8192, "llama3-70b-8192": {"token_limit": 8192, "supports_temperature": True},
"supports_temperature": True "mixtral-8x7b-32768": {"token_limit": 32768, "supports_temperature": True},
}, "gemma-7b-it": {"token_limit": 8192, "supports_temperature": True},
"llama3-70b-8192": { "claude-3-haiku-20240307'": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 8192,
"supports_temperature": True
},
"mixtral-8x7b-32768": {
"token_limit": 32768,
"supports_temperature": True
},
"gemma-7b-it": {
"token_limit": 8192,
"supports_temperature": True
},
"claude-3-haiku-20240307'": {
"token_limit": 8192,
"supports_temperature": True
}
}, },
"toghetherai": { "toghetherai": {
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": { "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": { "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"mistralai/Mixtral-8x22B-Instruct-v0.1": { "mistralai/Mixtral-8x22B-Instruct-v0.1": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"stabilityai/stable-diffusion-xl-base-1.0": { "stabilityai/stable-diffusion-xl-base-1.0": {
"token_limit": 2048, "token_limit": 2048,
"supports_temperature": True "supports_temperature": True,
}, },
"meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": { "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"NousResearch/Hermes-3-Llama-3.1-405B-Turbo": { "NousResearch/Hermes-3-Llama-3.1-405B-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"Gryphe/MythoMax-L2-13b-Lite": { "Gryphe/MythoMax-L2-13b-Lite": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
},
"Salesforce/Llama-Rank-V1": {
"token_limit": 8192,
"supports_temperature": True
}, },
"Salesforce/Llama-Rank-V1": {"token_limit": 8192, "supports_temperature": True},
"meta-llama/Meta-Llama-Guard-3-8B": { "meta-llama/Meta-Llama-Guard-3-8B": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"meta-llama/Meta-Llama-3-70B-Instruct-Turbo": { "meta-llama/Meta-Llama-3-70B-Instruct-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"meta-llama/Llama-3-8b-chat-hf": { "meta-llama/Llama-3-8b-chat-hf": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"meta-llama/Llama-3-70b-chat-hf": { "meta-llama/Llama-3-70b-chat-hf": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"Qwen/Qwen2-72B-Instruct": { "Qwen/Qwen2-72B-Instruct": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}, },
"google/gemma-2-27b-it": { "google/gemma-2-27b-it": {"token_limit": 8192, "supports_temperature": True},
"token_limit": 8192,
"supports_temperature": True
}
}, },
"anthropic": { "anthropic": {
"claude_instant": { "claude_instant": {"token_limit": 100000, "supports_temperature": True},
"token_limit": 100000, "claude2": {"token_limit": 9000, "supports_temperature": True},
"supports_temperature": True "claude2.1": {"token_limit": 200000, "supports_temperature": True},
}, "claude3": {"token_limit": 200000, "supports_temperature": True},
"claude2": { "claude3.5": {"token_limit": 200000, "supports_temperature": True},
"token_limit": 9000, "claude-3-opus-20240229": {"token_limit": 200000, "supports_temperature": True},
"supports_temperature": True
},
"claude2.1": {
"token_limit": 200000,
"supports_temperature": True
},
"claude3": {
"token_limit": 200000,
"supports_temperature": True
},
"claude3.5": {
"token_limit": 200000,
"supports_temperature": True
},
"claude-3-opus-20240229": {
"token_limit": 200000,
"supports_temperature": True
},
"claude-3-sonnet-20240229": { "claude-3-sonnet-20240229": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"claude-3-haiku-20240307": { "claude-3-haiku-20240307": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"claude-3-5-sonnet-20240620": { "claude-3-5-sonnet-20240620": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"claude-3-5-sonnet-20241022": { "claude-3-5-sonnet-20241022": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"claude-3-5-haiku-latest": { "claude-3-5-haiku-latest": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
} },
}, },
"bedrock": { "bedrock": {
"anthropic.claude-3-haiku-20240307-v1:0": { "anthropic.claude-3-haiku-20240307-v1:0": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"anthropic.claude-3-sonnet-20240229-v1:0": { "anthropic.claude-3-sonnet-20240229-v1:0": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"anthropic.claude-3-opus-20240229-v1:0": { "anthropic.claude-3-opus-20240229-v1:0": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"anthropic.claude-3-5-sonnet-20240620-v1:0": { "anthropic.claude-3-5-sonnet-20240620-v1:0": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
}, },
"claude-3-5-haiku-latest": { "claude-3-5-haiku-latest": {
"token_limit": 200000, "token_limit": 200000,
"supports_temperature": True "supports_temperature": True,
},
"anthropic.claude-v2:1": {
"token_limit": 200000,
"supports_temperature": True
},
"anthropic.claude-v2": {
"token_limit": 100000,
"supports_temperature": True
}, },
"anthropic.claude-v2:1": {"token_limit": 200000, "supports_temperature": True},
"anthropic.claude-v2": {"token_limit": 100000, "supports_temperature": True},
"anthropic.claude-instant-v1": { "anthropic.claude-instant-v1": {
"token_limit": 100000, "token_limit": 100000,
"supports_temperature": True "supports_temperature": True,
}, },
"meta.llama3-8b-instruct-v1:0": { "meta.llama3-8b-instruct-v1:0": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
}, },
"meta.llama3-70b-instruct-v1:0": { "meta.llama3-70b-instruct-v1:0": {
"token_limit": 8192, "token_limit": 8192,
"supports_temperature": True "supports_temperature": True,
},
"meta.llama2-13b-chat-v1": {
"token_limit": 4096,
"supports_temperature": True
},
"meta.llama2-70b-chat-v1": {
"token_limit": 4096,
"supports_temperature": True
}, },
"meta.llama2-13b-chat-v1": {"token_limit": 4096, "supports_temperature": True},
"meta.llama2-70b-chat-v1": {"token_limit": 4096, "supports_temperature": True},
"mistral.mistral-7b-instruct-v0:2": { "mistral.mistral-7b-instruct-v0:2": {
"token_limit": 32768, "token_limit": 32768,
"supports_temperature": True "supports_temperature": True,
}, },
"mistral.mixtral-8x7b-instruct-v0:1": { "mistral.mixtral-8x7b-instruct-v0:1": {
"token_limit": 32768, "token_limit": 32768,
"supports_temperature": True "supports_temperature": True,
}, },
"mistral.mistral-large-2402-v1:0": { "mistral.mistral-large-2402-v1:0": {
"token_limit": 32768, "token_limit": 32768,
"supports_temperature": True "supports_temperature": True,
}, },
"mistral.mistral-small-2402-v1:0": { "mistral.mistral-small-2402-v1:0": {
"token_limit": 32768, "token_limit": 32768,
"supports_temperature": True "supports_temperature": True,
}, },
"amazon.titan-embed-text-v1": { "amazon.titan-embed-text-v1": {
"token_limit": 8000, "token_limit": 8000,
"supports_temperature": True "supports_temperature": True,
}, },
"amazon.titan-embed-text-v2:0": { "amazon.titan-embed-text-v2:0": {
"token_limit": 8000, "token_limit": 8000,
"supports_temperature": True "supports_temperature": True,
},
"cohere.embed-english-v3": {
"token_limit": 512,
"supports_temperature": True
}, },
"cohere.embed-english-v3": {"token_limit": 512, "supports_temperature": True},
"cohere.embed-multilingual-v3": { "cohere.embed-multilingual-v3": {
"token_limit": 512, "token_limit": 512,
"supports_temperature": True "supports_temperature": True,
} },
}, },
"mistralai": { "mistralai": {
"mistral-large-latest": { "mistral-large-latest": {"token_limit": 128000, "supports_temperature": True},
"token_limit": 128000, "open-mistral-nemo": {"token_limit": 128000, "supports_temperature": True},
"supports_temperature": True "codestral-latest": {"token_limit": 32000, "supports_temperature": True},
},
"open-mistral-nemo": {
"token_limit": 128000,
"supports_temperature": True
},
"codestral-latest": {
"token_limit": 32000,
"supports_temperature": True
}
}, },
"togetherai": { "togetherai": {
"Meta-Llama-3.1-70B-Instruct-Turbo": { "Meta-Llama-3.1-70B-Instruct-Turbo": {
"token_limit": 128000, "token_limit": 128000,
"supports_temperature": True "supports_temperature": True,
}
} }
},
} }

View File

@ -2,4 +2,4 @@
from .server import run_server from .server import run_server
__all__ = ['run_server'] __all__ = ["run_server"]

View File

@ -1,17 +1,17 @@
"""Web interface server implementation for RA.Aid.""" """Web interface server implementation for RA.Aid."""
import asyncio import asyncio
import logging
import shutil import shutil
import sys import sys
import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any from typing import Any, Dict, List
import uvicorn import uvicorn
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
# Configure logging # Configure logging
@ -19,8 +19,10 @@ logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Verify ra-aid is available # Verify ra-aid is available
if not shutil.which('ra-aid'): if not shutil.which("ra-aid"):
logger.error("ra-aid command not found. Please ensure it's installed and in your PATH") logger.error(
"ra-aid command not found. Please ensure it's installed and in your PATH"
)
sys.exit(1) sys.exit(1)
app = FastAPI() app = FastAPI()
@ -45,6 +47,7 @@ logger.info(f"Using static directory: {STATIC_DIR}")
# Setup templates # Setup templates
templates = Jinja2Templates(directory=str(STATIC_DIR)) templates = Jinja2Templates(directory=str(STATIC_DIR))
class ConnectionManager: class ConnectionManager:
def __init__(self): def __init__(self):
self.active_connections: List[WebSocket] = [] self.active_connections: List[WebSocket] = []
@ -73,33 +76,40 @@ class ConnectionManager:
async def handle_error(self, websocket: WebSocket, error_message: str): async def handle_error(self, websocket: WebSocket, error_message: str):
try: try:
await websocket.send_json({ await websocket.send_json(
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"tools": { "tools": {
"messages": [{ "messages": [
{
"content": f"Error: {error_message}", "content": f"Error: {error_message}",
"status": "error" "status": "error",
}]
} }
]
} }
}) },
}
)
except Exception as e: except Exception as e:
logger.error(f"Error sending error message: {e}") logger.error(f"Error sending error message: {e}")
manager = ConnectionManager() manager = ConnectionManager()
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
async def get_root(request: Request): async def get_root(request: Request):
"""Serve the index.html file with port parameter.""" """Serve the index.html file with port parameter."""
return templates.TemplateResponse( return templates.TemplateResponse(
"index.html", "index.html", {"request": request, "server_port": request.url.port or 8080}
{"request": request, "server_port": request.url.port or 8080}
) )
# Mount static files for js and other assets # Mount static files for js and other assets
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
client_id = id(websocket) client_id = id(websocket)
@ -113,17 +123,19 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
# Send initial connection success message # Send initial connection success message
await manager.send_message(websocket, { await manager.send_message(
websocket,
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"agent": { "agent": {
"messages": [{ "messages": [
"content": "Connected to RA.Aid server", {"content": "Connected to RA.Aid server", "status": "info"}
"status": "info" ]
}]
} }
} },
}) },
)
while True: while True:
try: try:
@ -131,9 +143,7 @@ async def websocket_endpoint(websocket: WebSocket):
logger.debug(f"Received message from client {client_id}: {message}") logger.debug(f"Received message from client {client_id}: {message}")
if message["type"] == "request": if message["type"] == "request":
await manager.send_message(websocket, { await manager.send_message(websocket, {"type": "stream_start"})
"type": "stream_start"
})
try: try:
# Run ra-aid with the request # Run ra-aid with the request
@ -143,7 +153,7 @@ async def websocket_endpoint(websocket: WebSocket):
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
logger.info(f"Process started with PID: {process.pid}") logger.info(f"Process started with PID: {process.pid}")
@ -156,23 +166,32 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
decoded_line = line.decode().strip() decoded_line = line.decode().strip()
if decoded_line: if decoded_line:
await manager.send_message(websocket, { await manager.send_message(
websocket,
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"tools" if is_error else "agent": { "tools" if is_error else "agent": {
"messages": [{ "messages": [
{
"content": decoded_line, "content": decoded_line,
"status": "error" if is_error else "info" "status": "error"
}] if is_error
else "info",
} }
]
} }
}) },
},
)
except Exception as e: except Exception as e:
logger.error(f"Error processing output: {e}") logger.error(f"Error processing output: {e}")
# Create tasks for reading stdout and stderr # Create tasks for reading stdout and stderr
stdout_task = asyncio.create_task(read_stream(process.stdout)) stdout_task = asyncio.create_task(read_stream(process.stdout))
stderr_task = asyncio.create_task(read_stream(process.stderr, True)) stderr_task = asyncio.create_task(
read_stream(process.stderr, True)
)
# Wait for both streams to complete # Wait for both streams to complete
await asyncio.gather(stdout_task, stderr_task) await asyncio.gather(stdout_task, stderr_task)
@ -182,14 +201,13 @@ async def websocket_endpoint(websocket: WebSocket):
if return_code != 0: if return_code != 0:
await manager.handle_error( await manager.handle_error(
websocket, websocket, f"Process exited with code {return_code}"
f"Process exited with code {return_code}"
) )
await manager.send_message(websocket, { await manager.send_message(
"type": "stream_end", websocket,
"request": message["content"] {"type": "stream_end", "request": message["content"]},
}) )
except Exception as e: except Exception as e:
logger.error(f"Error executing ra-aid: {e}") logger.error(f"Error executing ra-aid: {e}")
@ -207,6 +225,7 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket) manager.disconnect(websocket)
logger.info(f"WebSocket connection cleaned up for client {client_id}") logger.info(f"WebSocket connection cleaned up for client {client_id}")
def run_server(host: str = "0.0.0.0", port: int = 8080): def run_server(host: str = "0.0.0.0", port: int = 8080):
"""Run the FastAPI server.""" """Run the FastAPI server."""
logger.info(f"Starting server on {host}:{port}") logger.info(f"Starting server on {host}:{port}")
@ -216,5 +235,5 @@ def run_server(host: str = "0.0.0.0", port: int = 8080):
port=port, port=port,
log_level="debug", log_level="debug",
ws_max_size=16777216, # 16MB ws_max_size=16777216, # 16MB
timeout_keep_alive=0 # Disable keep-alive timeout timeout_keep_alive=0, # Disable keep-alive timeout
) )

View File

@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory):
"""Test get_model_token_limit with Anthropic model.""" """Test get_model_token_limit with Anthropic model."""
config = {"provider": "anthropic", "model": "claude2"} config = {"provider": "anthropic", "model": "claude2"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory):
"""Test get_model_token_limit with OpenAI model.""" """Test get_model_token_limit with OpenAI model."""
config = {"provider": "openai", "model": "gpt-4"} config = {"provider": "openai", "model": "gpt-4"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] assert token_limit == models_params["openai"]["gpt-4"]["token_limit"]
@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory):
"""Test get_model_token_limit with unknown provider/model.""" """Test get_model_token_limit with unknown provider/model."""
config = {"provider": "unknown", "model": "unknown-model"} config = {"provider": "unknown", "model": "unknown-model"}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory):
"""Test get_model_token_limit with missing configuration.""" """Test get_model_token_limit with missing configuration."""
config = {} config = {}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success():
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 100000} mock_get_info.return_value = {"max_input_tokens": 100000}
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == 100000 assert token_limit == 100000
@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found():
mock_get_info.side_effect = litellm.exceptions.NotFoundError( mock_get_info.side_effect = litellm.exceptions.NotFoundError(
message="Model not found", model="claude-2", llm_provider="anthropic" message="Model not found", model="claude-2", llm_provider="anthropic"
) )
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error():
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.side_effect = Exception("Unknown error") mock_get_info.side_effect = Exception("Unknown error")
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] assert token_limit == models_params["anthropic"]["claude2"]["token_limit"]
@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error():
"""Test returning None when unexpected errors occur.""" """Test returning None when unexpected errors occur."""
config = None # This will cause an attribute error when accessed config = None # This will cause an attribute error when accessed
token_limit = get_model_token_limit(config) token_limit = get_model_token_limit(config, "default")
assert token_limit is None assert token_limit is None
@ -247,3 +247,31 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory)
assert agent == "react_agent" assert agent == "react_agent"
mock_react.assert_called_once_with(mock_model, []) mock_react.assert_called_once_with(mock_model, [])
def test_get_model_token_limit_research(mock_memory):
"""Test get_model_token_limit with research provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"research_provider": "anthropic",
"research_model": "claude-2",
}
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 150000}
token_limit = get_model_token_limit(config, "research")
assert token_limit == 150000
def test_get_model_token_limit_planner(mock_memory):
"""Test get_model_token_limit with planner provider and model."""
config = {
"provider": "openai",
"model": "gpt-4",
"planner_provider": "deepseek",
"planner_model": "dsm-1",
}
with patch("ra_aid.agent_utils.get_model_info") as mock_get_info:
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000

View File

@ -1,22 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio
import argparse import argparse
import json import asyncio
import sys
import shutil import shutil
import sys
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any from typing import List
import uvicorn import uvicorn
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
# Verify ra-aid is available # Verify ra-aid is available
if not shutil.which('ra-aid'): if not shutil.which("ra-aid"):
print("Error: ra-aid command not found. Please ensure it's installed and in your PATH") print(
"Error: ra-aid command not found. Please ensure it's installed and in your PATH"
)
sys.exit(1) sys.exit(1)
app = FastAPI() app = FastAPI()
@ -33,15 +34,16 @@ app.add_middleware(
# Setup templates # Setup templates
templates = Jinja2Templates(directory=Path(__file__).parent) templates = Jinja2Templates(directory=Path(__file__).parent)
# Create a route for the root to serve index.html with port parameter # Create a route for the root to serve index.html with port parameter
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
async def get_root(request: Request): async def get_root(request: Request):
"""Serve the index.html file with port parameter.""" """Serve the index.html file with port parameter."""
return templates.TemplateResponse( return templates.TemplateResponse(
"index.html", "index.html", {"request": request, "server_port": request.url.port or 8080}
{"request": request, "server_port": request.url.port or 8080}
) )
# Mount static files for js and other assets # Mount static files for js and other assets
app.mount("/static", StaticFiles(directory=Path(__file__).parent), name="static") app.mount("/static", StaticFiles(directory=Path(__file__).parent), name="static")
@ -50,6 +52,7 @@ app.mount("/static", StaticFiles(directory=Path(__file__).parent), name="static"
# Store active WebSocket connections # Store active WebSocket connections
active_connections: List[WebSocket] = [] active_connections: List[WebSocket] = []
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
print(f"New WebSocket connection from {websocket.client}") print(f"New WebSocket connection from {websocket.client}")
@ -66,9 +69,7 @@ async def websocket_endpoint(websocket: WebSocket):
if message["type"] == "request": if message["type"] == "request":
print(f"Processing request: {message['content']}") print(f"Processing request: {message['content']}")
# Notify client that streaming is starting # Notify client that streaming is starting
await websocket.send_json({ await websocket.send_json({"type": "stream_start"})
"type": "stream_start"
})
try: try:
# Run ra-aid with the request using -m flag and cowboy mode # Run ra-aid with the request using -m flag and cowboy mode
@ -79,7 +80,7 @@ async def websocket_endpoint(websocket: WebSocket):
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
print(f"Process started with PID: {process.pid}") print(f"Process started with PID: {process.pid}")
@ -97,17 +98,25 @@ async def websocket_endpoint(websocket: WebSocket):
decoded_line = line.decode().strip() decoded_line = line.decode().strip()
print(f"{stream_type} line: {decoded_line}") print(f"{stream_type} line: {decoded_line}")
if decoded_line: if decoded_line:
await websocket.send_json({ await websocket.send_json(
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"tools" if is_error else "agent": { "tools" if is_error else "agent": {
"messages": [{ "messages": [
{
"content": decoded_line, "content": decoded_line,
"status": "error" if is_error else "info" "status": (
}] "error"
if is_error
else "info"
),
} }
]
} }
}) },
}
)
except Exception as e: except Exception as e:
print(f"Error sending output: {e}") print(f"Error sending output: {e}")
@ -122,74 +131,80 @@ async def websocket_endpoint(websocket: WebSocket):
return_code = await process.wait() return_code = await process.wait()
if return_code != 0: if return_code != 0:
await websocket.send_json({ await websocket.send_json(
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"tools": { "tools": {
"messages": [{ "messages": [
{
"content": f"Process exited with code {return_code}", "content": f"Process exited with code {return_code}",
"status": "error" "status": "error",
}]
} }
]
} }
}) },
}
)
# Notify client that streaming is complete # Notify client that streaming is complete
await websocket.send_json({ await websocket.send_json(
"type": "stream_end", {"type": "stream_end", "request": message["content"]}
"request": message["content"] )
})
except Exception as e: except Exception as e:
error_msg = f"Error executing ra-aid: {str(e)}" error_msg = f"Error executing ra-aid: {str(e)}"
print(error_msg) print(error_msg)
await websocket.send_json({ await websocket.send_json(
{
"type": "chunk", "type": "chunk",
"chunk": { "chunk": {
"tools": { "tools": {
"messages": [{ "messages": [
"content": error_msg, {"content": error_msg, "status": "error"}
"status": "error" ]
}]
} }
},
} }
}) )
except WebSocketDisconnect: except WebSocketDisconnect:
print(f"WebSocket client disconnected") print("WebSocket client disconnected")
active_connections.remove(websocket) active_connections.remove(websocket)
except Exception as e: except Exception as e:
print(f"WebSocket error: {e}") print(f"WebSocket error: {e}")
try: try:
await websocket.send_json({ await websocket.send_json({"type": "error", "error": str(e)})
"type": "error", except Exception:
"error": str(e)
})
except:
pass pass
finally: finally:
if websocket in active_connections: if websocket in active_connections:
active_connections.remove(websocket) active_connections.remove(websocket)
print("WebSocket connection cleaned up") print("WebSocket connection cleaned up")
@app.get("/config") @app.get("/config")
async def get_config(request: Request): async def get_config(request: Request):
"""Return server configuration including host and port.""" """Return server configuration including host and port."""
return { return {"host": request.client.host, "port": request.scope.get("server")[1]}
"host": request.client.host,
"port": request.scope.get("server")[1]
}
def run_server(host: str = "0.0.0.0", port: int = 8080): def run_server(host: str = "0.0.0.0", port: int = 8080):
"""Run the FastAPI server.""" """Run the FastAPI server."""
uvicorn.run(app, host=host, port=port) uvicorn.run(app, host=host, port=port)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server") parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server")
parser.add_argument("--port", type=int, default=8080, parser.add_argument(
help="Port to listen on (default: 8080)") "--port", type=int, default=8080, help="Port to listen on (default: 8080)"
parser.add_argument("--host", type=str, default="0.0.0.0", )
help="Host to listen on (default: 0.0.0.0)") parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to listen on (default: 0.0.0.0)",
)
args = parser.parse_args() args = parser.parse_args()
run_server(host=args.host, port=args.port) run_server(host=args.host, port=args.port)