From c2ba638a95007dffce668ddc49db403fca9867c0 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Sat, 1 Feb 2025 12:55:36 -0800 Subject: [PATCH 1/2] feat(agent_utils.py): enhance get_model_token_limit to support agent types for better configuration management test(agent_utils.py): add tests for get_model_token_limit with different agent types to ensure correct functionality --- ra_aid/agent_utils.py | 66 +++++++++++++++++++++----------- tests/ra_aid/test_agent_utils.py | 42 ++++++++++++++++---- 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 05b7bde..74ce26d 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -5,7 +5,7 @@ import sys import threading import time import uuid -from typing import Any, Dict, List, Optional, Sequence +from typing import Any, Dict, List, Literal, Optional, Sequence import litellm from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError @@ -122,15 +122,24 @@ def state_modifier( return [first_message] + trimmed_remaining -def get_model_token_limit(config: Dict[str, Any]) -> Optional[int]: - """Get the token limit for the current model configuration. +def get_model_token_limit( + config: Dict[str, Any], agent_type: Literal["default", "research", "planner"] +) -> Optional[int]: + """Get the token limit for the current model configuration based on agent type. Returns: Optional[int]: The token limit if found, None otherwise """ try: - provider = config.get("provider", "") - model_name = config.get("model", "") + if agent_type == "research": + provider = config.get("research_provider", "") or config.get("provider", "") + model_name = config.get("research_model", "") or config.get("model", "") + elif agent_type == "planner": + provider = config.get("planner_provider", "") or config.get("provider", "") + model_name = config.get("planner_model", "") or config.get("model", "") + else: + provider = config.get("provider", "") + model_name = config.get("model", "") try: provider_model = model_name if not provider else f"{provider}/{model_name}" @@ -224,6 +233,7 @@ def create_agent( tools: List[Any], *, checkpointer: Any = None, + agent_type: str = "default", ) -> Any: """Create a react agent with the given configuration. @@ -245,7 +255,9 @@ def create_agent( """ try: config = _global_memory.get("config", {}) - max_input_tokens = get_model_token_limit(config) or DEFAULT_TOKEN_LIMIT + max_input_tokens = ( + get_model_token_limit(config, agent_type) or DEFAULT_TOKEN_LIMIT + ) # Use REACT agent for Anthropic Claude models, otherwise use CIAYN if is_anthropic_claude(config): @@ -260,7 +272,7 @@ def create_agent( # Default to REACT agent if provider/model detection fails logger.warning(f"Failed to detect model type: {e}. 
Defaulting to REACT agent.") config = _global_memory.get("config", {}) - max_input_tokens = get_model_token_limit(config) + max_input_tokens = get_model_token_limit(config, agent_type) agent_kwargs = build_agent_kwargs(checkpointer, config, max_input_tokens) return create_react_agent(model, tools, **agent_kwargs) @@ -326,7 +338,7 @@ def run_research_agent( web_research_enabled=config.get("web_research_enabled", False), ) - agent = create_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory, agent_type="research") expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" @@ -349,9 +361,11 @@ def run_research_agent( prompt = (RESEARCH_ONLY_PROMPT if research_only else RESEARCH_PROMPT).format( base_task=base_task_or_query, - research_only_note="" - if research_only - else " Only request implementation if the user explicitly asked for changes to be made.", + research_only_note=( + "" + if research_only + else " Only request implementation if the user explicitly asked for changes to be made." + ), expert_section=expert_section, human_section=human_section, web_research_section=web_research_section, @@ -455,7 +469,7 @@ def run_web_research_agent( tools = get_web_research_tools(expert_enabled=expert_enabled) - agent = create_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory, agent_type="research") expert_section = EXPERT_PROMPT_SECTION_RESEARCH if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_RESEARCH if hil else "" @@ -536,7 +550,7 @@ def run_planning_agent( web_research_enabled=config.get("web_research_enabled", False), ) - agent = create_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory, agent_type="planner") expert_section = EXPERT_PROMPT_SECTION_PLANNING if expert_enabled else "" human_section = HUMAN_PROMPT_SECTION_PLANNING if hil else "" @@ -556,9 +570,11 @@ def run_planning_agent( key_facts=get_memory_value("key_facts"), key_snippets=get_memory_value("key_snippets"), work_log=get_memory_value("work_log"), - research_only_note="" - if config.get("research_only") - else " Only request implementation if the user explicitly asked for changes to be made.", + research_only_note=( + "" + if config.get("research_only") + else " Only request implementation if the user explicitly asked for changes to be made." 
+ ), ) config = _global_memory.get("config", {}) if not config else config @@ -634,7 +650,7 @@ def run_task_implementation_agent( web_research_enabled=config.get("web_research_enabled", False), ) - agent = create_agent(model, tools, checkpointer=memory) + agent = create_agent(model, tools, checkpointer=memory, agent_type="planner") prompt = IMPLEMENTATION_PROMPT.format( base_task=base_task, @@ -647,12 +663,16 @@ def run_task_implementation_agent( research_notes=get_memory_value("research_notes"), work_log=get_memory_value("work_log"), expert_section=EXPERT_PROMPT_SECTION_IMPLEMENTATION if expert_enabled else "", - human_section=HUMAN_PROMPT_SECTION_IMPLEMENTATION - if _global_memory.get("config", {}).get("hil", False) - else "", - web_research_section=WEB_RESEARCH_PROMPT_SECTION_CHAT - if config.get("web_research_enabled") - else "", + human_section=( + HUMAN_PROMPT_SECTION_IMPLEMENTATION + if _global_memory.get("config", {}).get("hil", False) + else "" + ), + web_research_section=( + WEB_RESEARCH_PROMPT_SECTION_CHAT + if config.get("web_research_enabled") + else "" + ), ) config = _global_memory.get("config", {}) if not config else config diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index d557c97..549a862 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -35,7 +35,7 @@ def test_get_model_token_limit_anthropic(mock_memory): """Test get_model_token_limit with Anthropic model.""" config = {"provider": "anthropic", "model": "claude2"} - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] @@ -43,7 +43,7 @@ def test_get_model_token_limit_openai(mock_memory): """Test get_model_token_limit with OpenAI model.""" config = {"provider": "openai", "model": "gpt-4"} - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["openai"]["gpt-4"]["token_limit"] @@ -51,7 +51,7 @@ def test_get_model_token_limit_unknown(mock_memory): """Test get_model_token_limit with unknown provider/model.""" config = {"provider": "unknown", "model": "unknown-model"} - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit is None @@ -59,7 +59,7 @@ def test_get_model_token_limit_missing_config(mock_memory): """Test get_model_token_limit with missing configuration.""" config = {} - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit is None @@ -69,7 +69,7 @@ def test_get_model_token_limit_litellm_success(): with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 100000} - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit == 100000 @@ -81,7 +81,7 @@ def test_get_model_token_limit_litellm_not_found(): mock_get_info.side_effect = litellm.exceptions.NotFoundError( message="Model not found", model="claude-2", llm_provider="anthropic" ) - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] @@ -91,7 +91,7 @@ def test_get_model_token_limit_litellm_error(): with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.side_effect = Exception("Unknown error") - 
token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit == models_params["anthropic"]["claude2"]["token_limit"] @@ -99,7 +99,7 @@ def test_get_model_token_limit_unexpected_error(): """Test returning None when unexpected errors occur.""" config = None # This will cause an attribute error when accessed - token_limit = get_model_token_limit(config) + token_limit = get_model_token_limit(config, "default") assert token_limit is None @@ -247,3 +247,29 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory) assert agent == "react_agent" mock_react.assert_called_once_with(mock_model, []) + +def test_get_model_token_limit_research(mock_memory): + """Test get_model_token_limit with research provider and model.""" + config = { + "provider": "openai", + "model": "gpt-4", + "research_provider": "anthropic", + "research_model": "claude-2" + } + with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: + mock_get_info.return_value = {"max_input_tokens": 150000} + token_limit = get_model_token_limit(config, "research") + assert token_limit == 150000 + +def test_get_model_token_limit_planner(mock_memory): + """Test get_model_token_limit with planner provider and model.""" + config = { + "provider": "openai", + "model": "gpt-4", + "planner_provider": "deepseek", + "planner_model": "dsm-1" + } + with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: + mock_get_info.return_value = {"max_input_tokens": 120000} + token_limit = get_model_token_limit(config, "planner") + assert token_limit == 120000 From 1cc5b8e16cdc35f1e70f53f289b20034fdefecd1 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Sat, 1 Feb 2025 13:00:15 -0800 Subject: [PATCH 2/2] refactor: clean up imports and improve code formatting for better readability fix: ensure proper error handling and logging in WebSocket connections feat: enhance WebSocket server to handle streaming and error messages more effectively chore: update model parameters and configurations for better performance and maintainability test: improve test coverage for model token limits and agent creation logic --- ra_aid/__main__.py | 2 + ra_aid/llm.py | 14 +- ra_aid/models_params.py | 732 ++++++++----------------------- ra_aid/webui/__init__.py | 2 +- ra_aid/webui/server.py | 149 ++++--- tests/ra_aid/test_agent_utils.py | 6 +- webui/server.py | 161 ++++--- 7 files changed, 363 insertions(+), 703 deletions(-) diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 9eea37a..66ed881 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -34,9 +34,11 @@ logger = get_logger(__name__) def launch_webui(host: str, port: int): """Launch the RA.Aid web interface.""" from ra_aid.webui import run_server + print(f"Starting RA.Aid web interface on http://{host}:{port}") run_server(host=host, port=port) + def parse_arguments(args=None): VALID_PROVIDERS = [ "anthropic", diff --git a/ra_aid/llm.py b/ra_aid/llm.py index aebb497..3080813 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -1,9 +1,6 @@ import os from typing import Any, Dict, Optional -known_temp_providers = {"openai", "anthropic", "openrouter", "openai-compatible", "gemini", "deepseek"} - -from .models_params import models_params from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI @@ -12,6 +9,17 @@ from langchain_openai import ChatOpenAI from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner from 
ra_aid.logging_config import get_logger +from .models_params import models_params + +known_temp_providers = { + "openai", + "anthropic", + "openrouter", + "openai-compatible", + "gemini", + "deepseek", +} + logger = get_logger(__name__) diff --git a/ra_aid/models_params.py b/ra_aid/models_params.py index 85ee55a..4bbf596 100644 --- a/ra_aid/models_params.py +++ b/ra_aid/models_params.py @@ -6,710 +6,324 @@ DEFAULT_TOKEN_LIMIT = 100000 models_params = { "openai": { - "gpt-3.5-turbo-0125": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5": { - "token_limit": 4096, - "supports_temperature": True - }, - "gpt-3.5-turbo": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5-turbo-1106": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5-turbo-instruct": { - "token_limit": 4096, - "supports_temperature": True - }, - "gpt-4-0125-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo-2024-04-09": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-1106-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-vision-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4": { - "token_limit": 8192, - "supports_temperature": True - }, - "gpt-4-0613": { - "token_limit": 8192, - "supports_temperature": True - }, - "gpt-4-32k": { - "token_limit": 32768, - "supports_temperature": True - }, - "gpt-4-32k-0613": { - "token_limit": 32768, - "supports_temperature": True - }, - "gpt-4o": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4o-2024-08-06": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4o-2024-05-13": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4o-mini": { - "token_limit": 128000, - "supports_temperature": True - }, - "o1-preview": { - "token_limit": 128000, - "supports_temperature": False - }, - "o1-mini": { - "token_limit": 128000, - "supports_temperature": False - } + "gpt-3.5-turbo-0125": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5": {"token_limit": 4096, "supports_temperature": True}, + "gpt-3.5-turbo": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5-turbo-1106": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5-turbo-instruct": {"token_limit": 4096, "supports_temperature": True}, + "gpt-4-0125-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo-2024-04-09": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-1106-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-vision-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4": {"token_limit": 8192, "supports_temperature": True}, + "gpt-4-0613": {"token_limit": 8192, "supports_temperature": True}, + "gpt-4-32k": {"token_limit": 32768, "supports_temperature": True}, + "gpt-4-32k-0613": {"token_limit": 32768, "supports_temperature": True}, + "gpt-4o": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4o-2024-08-06": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4o-2024-05-13": {"token_limit": 128000, "supports_temperature": True}, + 
"gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True}, + "o1-preview": {"token_limit": 128000, "supports_temperature": False}, + "o1-mini": {"token_limit": 128000, "supports_temperature": False}, }, "azure_openai": { - "gpt-3.5-turbo-0125": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5": { - "token_limit": 4096, - "supports_temperature": True - }, - "gpt-3.5-turbo": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5-turbo-1106": { - "token_limit": 16385, - "supports_temperature": True - }, - "gpt-3.5-turbo-instruct": { - "token_limit": 4096, - "supports_temperature": True - }, - "gpt-4-0125-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-turbo-2024-04-09": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-1106-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4-vision-preview": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4": { - "token_limit": 8192, - "supports_temperature": True - }, - "gpt-4-0613": { - "token_limit": 8192, - "supports_temperature": True - }, - "gpt-4-32k": { - "token_limit": 32768, - "supports_temperature": True - }, - "gpt-4-32k-0613": { - "token_limit": 32768, - "supports_temperature": True - }, - "gpt-4o": { - "token_limit": 128000, - "supports_temperature": True - }, - "gpt-4o-mini": { - "token_limit": 128000, - "supports_temperature": True - }, - "chatgpt-4o-latest": { - "token_limit": 128000, - "supports_temperature": True - }, - "o1-preview": { - "token_limit": 128000, - "supports_temperature": False - }, - "o1-mini": { - "token_limit": 128000, - "supports_temperature": False - } + "gpt-3.5-turbo-0125": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5": {"token_limit": 4096, "supports_temperature": True}, + "gpt-3.5-turbo": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5-turbo-1106": {"token_limit": 16385, "supports_temperature": True}, + "gpt-3.5-turbo-instruct": {"token_limit": 4096, "supports_temperature": True}, + "gpt-4-0125-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-turbo-2024-04-09": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-1106-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4-vision-preview": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4": {"token_limit": 8192, "supports_temperature": True}, + "gpt-4-0613": {"token_limit": 8192, "supports_temperature": True}, + "gpt-4-32k": {"token_limit": 32768, "supports_temperature": True}, + "gpt-4-32k-0613": {"token_limit": 32768, "supports_temperature": True}, + "gpt-4o": {"token_limit": 128000, "supports_temperature": True}, + "gpt-4o-mini": {"token_limit": 128000, "supports_temperature": True}, + "chatgpt-4o-latest": {"token_limit": 128000, "supports_temperature": True}, + "o1-preview": {"token_limit": 128000, "supports_temperature": False}, + "o1-mini": {"token_limit": 128000, "supports_temperature": False}, }, "google_genai": { - "gemini-pro": { - "token_limit": 128000, - "supports_temperature": True - }, + "gemini-pro": {"token_limit": 128000, "supports_temperature": True}, 
"gemini-1.5-flash-latest": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, - "gemini-1.5-pro-latest": { - "token_limit": 128000, - "supports_temperature": True - }, - "models/embedding-001": { - "token_limit": 2048, - "supports_temperature": True - } + "gemini-1.5-pro-latest": {"token_limit": 128000, "supports_temperature": True}, + "models/embedding-001": {"token_limit": 2048, "supports_temperature": True}, }, "google_vertexai": { - "gemini-1.5-flash": { - "token_limit": 128000, - "supports_temperature": True - }, - "gemini-1.5-pro": { - "token_limit": 128000, - "supports_temperature": True - }, - "gemini-1.0-pro": { - "token_limit": 128000, - "supports_temperature": True - } + "gemini-1.5-flash": {"token_limit": 128000, "supports_temperature": True}, + "gemini-1.5-pro": {"token_limit": 128000, "supports_temperature": True}, + "gemini-1.0-pro": {"token_limit": 128000, "supports_temperature": True}, }, "ollama": { - "command-r": { - "token_limit": 12800, - "supports_temperature": True - }, - "codellama": { - "token_limit": 16000, - "supports_temperature": True - }, - "dbrx": { - "token_limit": 32768, - "supports_temperature": True - }, - "deepseek-coder:33b": { - "token_limit": 16000, - "supports_temperature": True - }, - "falcon": { - "token_limit": 2048, - "supports_temperature": True - }, - "llama2": { - "token_limit": 4096, - "supports_temperature": True - }, - "llama2:7b": { - "token_limit": 4096, - "supports_temperature": True - }, - "llama2:13b": { - "token_limit": 4096, - "supports_temperature": True - }, - "llama2:70b": { - "token_limit": 4096, - "supports_temperature": True - }, - "llama3": { - "token_limit": 8192, - "supports_temperature": True - }, - "llama3:8b": { - "token_limit": 8192, - "supports_temperature": True - }, - "llama3:70b": { - "token_limit": 8192, - "supports_temperature": True - }, - "llama3.1": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.1:8b": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.1:70b": { - "token_limit": 128000, - "supports_temperature": True - }, - "lama3.1:405b": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.2": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.2:1b": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.2:3b": { - "token_limit": 128000, - "supports_temperature": True - }, - "llama3.3:70b": { - "token_limit": 128000, - "supports_temperature": True - }, - "scrapegraph": { - "token_limit": 8192, - "supports_temperature": True - }, - "mistral-small": { - "token_limit": 128000, - "supports_temperature": True - }, - "mistral-openorca": { - "token_limit": 32000, - "supports_temperature": True - }, - "mistral-large": { - "token_limit": 128000, - "supports_temperature": True - }, - "grok-1": { - "token_limit": 8192, - "supports_temperature": True - }, - "llava": { - "token_limit": 4096, - "supports_temperature": True - }, - "mixtral:8x22b-instruct": { - "token_limit": 65536, - "supports_temperature": True - }, - "nomic-embed-text": { - "token_limit": 8192, - "supports_temperature": True - }, - "nous-hermes2:34b": { - "token_limit": 4096, - "supports_temperature": True - }, - "orca-mini": { - "token_limit": 2048, - "supports_temperature": True - }, - "phi3:3.8b": { - "token_limit": 12800, - "supports_temperature": True - }, - "phi3:14b": { - "token_limit": 128000, - "supports_temperature": True - }, - "qwen:0.5b": { - "token_limit": 32000, - "supports_temperature": True - 
}, - "qwen:1.8b": { - "token_limit": 32000, - "supports_temperature": True - }, - "qwen:4b": { - "token_limit": 32000, - "supports_temperature": True - }, - "qwen:14b": { - "token_limit": 32000, - "supports_temperature": True - }, - "qwen:32b": { - "token_limit": 32000, - "supports_temperature": True - }, - "qwen:72b": { - "token_limit": 32000, - "supports_temperature": True - }, - "qwen:110b": { - "token_limit": 32000, - "supports_temperature": True - }, - "stablelm-zephyr": { - "token_limit": 8192, - "supports_temperature": True - }, - "wizardlm2:8x22b": { - "token_limit": 65536, - "supports_temperature": True - }, - "mistral": { - "token_limit": 128000, - "supports_temperature": True - }, - "gemma2": { - "token_limit": 128000, - "supports_temperature": True - }, - "gemma2:9b": { - "token_limit": 128000, - "supports_temperature": True - }, - "gemma2:27b": { - "token_limit": 128000, - "supports_temperature": True - }, + "command-r": {"token_limit": 12800, "supports_temperature": True}, + "codellama": {"token_limit": 16000, "supports_temperature": True}, + "dbrx": {"token_limit": 32768, "supports_temperature": True}, + "deepseek-coder:33b": {"token_limit": 16000, "supports_temperature": True}, + "falcon": {"token_limit": 2048, "supports_temperature": True}, + "llama2": {"token_limit": 4096, "supports_temperature": True}, + "llama2:7b": {"token_limit": 4096, "supports_temperature": True}, + "llama2:13b": {"token_limit": 4096, "supports_temperature": True}, + "llama2:70b": {"token_limit": 4096, "supports_temperature": True}, + "llama3": {"token_limit": 8192, "supports_temperature": True}, + "llama3:8b": {"token_limit": 8192, "supports_temperature": True}, + "llama3:70b": {"token_limit": 8192, "supports_temperature": True}, + "llama3.1": {"token_limit": 128000, "supports_temperature": True}, + "llama3.1:8b": {"token_limit": 128000, "supports_temperature": True}, + "llama3.1:70b": {"token_limit": 128000, "supports_temperature": True}, + "lama3.1:405b": {"token_limit": 128000, "supports_temperature": True}, + "llama3.2": {"token_limit": 128000, "supports_temperature": True}, + "llama3.2:1b": {"token_limit": 128000, "supports_temperature": True}, + "llama3.2:3b": {"token_limit": 128000, "supports_temperature": True}, + "llama3.3:70b": {"token_limit": 128000, "supports_temperature": True}, + "scrapegraph": {"token_limit": 8192, "supports_temperature": True}, + "mistral-small": {"token_limit": 128000, "supports_temperature": True}, + "mistral-openorca": {"token_limit": 32000, "supports_temperature": True}, + "mistral-large": {"token_limit": 128000, "supports_temperature": True}, + "grok-1": {"token_limit": 8192, "supports_temperature": True}, + "llava": {"token_limit": 4096, "supports_temperature": True}, + "mixtral:8x22b-instruct": {"token_limit": 65536, "supports_temperature": True}, + "nomic-embed-text": {"token_limit": 8192, "supports_temperature": True}, + "nous-hermes2:34b": {"token_limit": 4096, "supports_temperature": True}, + "orca-mini": {"token_limit": 2048, "supports_temperature": True}, + "phi3:3.8b": {"token_limit": 12800, "supports_temperature": True}, + "phi3:14b": {"token_limit": 128000, "supports_temperature": True}, + "qwen:0.5b": {"token_limit": 32000, "supports_temperature": True}, + "qwen:1.8b": {"token_limit": 32000, "supports_temperature": True}, + "qwen:4b": {"token_limit": 32000, "supports_temperature": True}, + "qwen:14b": {"token_limit": 32000, "supports_temperature": True}, + "qwen:32b": {"token_limit": 32000, "supports_temperature": True}, + "qwen:72b": 
{"token_limit": 32000, "supports_temperature": True}, + "qwen:110b": {"token_limit": 32000, "supports_temperature": True}, + "stablelm-zephyr": {"token_limit": 8192, "supports_temperature": True}, + "wizardlm2:8x22b": {"token_limit": 65536, "supports_temperature": True}, + "mistral": {"token_limit": 128000, "supports_temperature": True}, + "gemma2": {"token_limit": 128000, "supports_temperature": True}, + "gemma2:9b": {"token_limit": 128000, "supports_temperature": True}, + "gemma2:27b": {"token_limit": 128000, "supports_temperature": True}, # embedding models "shaw/dmeta-embedding-zh-small-q4": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "shaw/dmeta-embedding-zh-q4": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "chevalblanc/acge_text_embedding": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "martcreation/dmeta-embedding-zh": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, - "snowflake-arctic-embed": { - "token_limit": 8192, - "supports_temperature": True - }, - "mxbai-embed-large": { - "token_limit": 512, - "supports_temperature": True - } - }, - "oneapi": { - "qwen-turbo": { - "token_limit": 6000, - "supports_temperature": True - } + "snowflake-arctic-embed": {"token_limit": 8192, "supports_temperature": True}, + "mxbai-embed-large": {"token_limit": 512, "supports_temperature": True}, }, + "oneapi": {"qwen-turbo": {"token_limit": 6000, "supports_temperature": True}}, "nvidia": { - "meta/llama3-70b-instruct": { - "token_limit": 419, - "supports_temperature": True - }, - "meta/llama3-8b-instruct": { - "token_limit": 419, - "supports_temperature": True - }, - "nemotron-4-340b-instruct": { - "token_limit": 1024, - "supports_temperature": True - }, - "databricks/dbrx-instruct": { - "token_limit": 4096, - "supports_temperature": True - }, - "google/codegemma-7b": { - "token_limit": 8192, - "supports_temperature": True - }, - "google/gemma-2b": { - "token_limit": 2048, - "supports_temperature": True - }, - "google/gemma-7b": { - "token_limit": 8192, - "supports_temperature": True - }, - "google/recurrentgemma-2b": { - "token_limit": 2048, - "supports_temperature": True - }, - "meta/codellama-70b": { - "token_limit": 16384, - "supports_temperature": True - }, - "meta/llama2-70b": { - "token_limit": 4096, - "supports_temperature": True - }, + "meta/llama3-70b-instruct": {"token_limit": 419, "supports_temperature": True}, + "meta/llama3-8b-instruct": {"token_limit": 419, "supports_temperature": True}, + "nemotron-4-340b-instruct": {"token_limit": 1024, "supports_temperature": True}, + "databricks/dbrx-instruct": {"token_limit": 4096, "supports_temperature": True}, + "google/codegemma-7b": {"token_limit": 8192, "supports_temperature": True}, + "google/gemma-2b": {"token_limit": 2048, "supports_temperature": True}, + "google/gemma-7b": {"token_limit": 8192, "supports_temperature": True}, + "google/recurrentgemma-2b": {"token_limit": 2048, "supports_temperature": True}, + "meta/codellama-70b": {"token_limit": 16384, "supports_temperature": True}, + "meta/llama2-70b": {"token_limit": 4096, "supports_temperature": True}, "microsoft/phi-3-mini-128k-instruct": { "token_limit": 122880, - "supports_temperature": True + "supports_temperature": True, }, "mistralai/mistral-7b-instruct-v0.2": { "token_limit": 4096, - "supports_temperature": True - }, - "mistralai/mistral-large": { - "token_limit": 8192, - "supports_temperature": True + 
"supports_temperature": True, }, + "mistralai/mistral-large": {"token_limit": 8192, "supports_temperature": True}, "mistralai/mixtral-8x22b-instruct-v0.1": { "token_limit": 32768, - "supports_temperature": True + "supports_temperature": True, }, "mistralai/mixtral-8x7b-instruct-v0.1": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, - "snowflake/arctic": { - "token_limit": 16384, - "supports_temperature": True - } + "snowflake/arctic": {"token_limit": 16384, "supports_temperature": True}, }, "groq": { - "llama3-8b-8192": { - "token_limit": 8192, - "supports_temperature": True - }, - "llama3-70b-8192": { - "token_limit": 8192, - "supports_temperature": True - }, - "mixtral-8x7b-32768": { - "token_limit": 32768, - "supports_temperature": True - }, - "gemma-7b-it": { - "token_limit": 8192, - "supports_temperature": True - }, - "claude-3-haiku-20240307'": { - "token_limit": 8192, - "supports_temperature": True - } + "llama3-8b-8192": {"token_limit": 8192, "supports_temperature": True}, + "llama3-70b-8192": {"token_limit": 8192, "supports_temperature": True}, + "mixtral-8x7b-32768": {"token_limit": 32768, "supports_temperature": True}, + "gemma-7b-it": {"token_limit": 8192, "supports_temperature": True}, + "claude-3-haiku-20240307'": {"token_limit": 8192, "supports_temperature": True}, }, "toghetherai": { "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "mistralai/Mixtral-8x22B-Instruct-v0.1": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "stabilityai/stable-diffusion-xl-base-1.0": { "token_limit": 2048, - "supports_temperature": True + "supports_temperature": True, }, "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "NousResearch/Hermes-3-Llama-3.1-405B-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "Gryphe/MythoMax-L2-13b-Lite": { "token_limit": 8192, - "supports_temperature": True - }, - "Salesforce/Llama-Rank-V1": { - "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, + "Salesforce/Llama-Rank-V1": {"token_limit": 8192, "supports_temperature": True}, "meta-llama/Meta-Llama-Guard-3-8B": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "meta-llama/Meta-Llama-3-70B-Instruct-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, "meta-llama/Llama-3-8b-chat-hf": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "meta-llama/Llama-3-70b-chat-hf": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "Qwen/Qwen2-72B-Instruct": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, }, - "google/gemma-2-27b-it": { - "token_limit": 8192, - "supports_temperature": True - } + "google/gemma-2-27b-it": {"token_limit": 8192, "supports_temperature": True}, }, "anthropic": { - "claude_instant": { - "token_limit": 100000, - "supports_temperature": True - }, - "claude2": { - "token_limit": 9000, - "supports_temperature": True - }, - "claude2.1": { - "token_limit": 200000, - "supports_temperature": True - }, - "claude3": { - "token_limit": 
200000, - "supports_temperature": True - }, - "claude3.5": { - "token_limit": 200000, - "supports_temperature": True - }, - "claude-3-opus-20240229": { - "token_limit": 200000, - "supports_temperature": True - }, + "claude_instant": {"token_limit": 100000, "supports_temperature": True}, + "claude2": {"token_limit": 9000, "supports_temperature": True}, + "claude2.1": {"token_limit": 200000, "supports_temperature": True}, + "claude3": {"token_limit": 200000, "supports_temperature": True}, + "claude3.5": {"token_limit": 200000, "supports_temperature": True}, + "claude-3-opus-20240229": {"token_limit": 200000, "supports_temperature": True}, "claude-3-sonnet-20240229": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "claude-3-haiku-20240307": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "claude-3-5-sonnet-20240620": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "claude-3-5-sonnet-20241022": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "claude-3-5-haiku-latest": { "token_limit": 200000, - "supports_temperature": True - } + "supports_temperature": True, + }, }, "bedrock": { "anthropic.claude-3-haiku-20240307-v1:0": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "anthropic.claude-3-sonnet-20240229-v1:0": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "anthropic.claude-3-opus-20240229-v1:0": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "anthropic.claude-3-5-sonnet-20240620-v1:0": { "token_limit": 200000, - "supports_temperature": True + "supports_temperature": True, }, "claude-3-5-haiku-latest": { "token_limit": 200000, - "supports_temperature": True - }, - "anthropic.claude-v2:1": { - "token_limit": 200000, - "supports_temperature": True - }, - "anthropic.claude-v2": { - "token_limit": 100000, - "supports_temperature": True + "supports_temperature": True, }, + "anthropic.claude-v2:1": {"token_limit": 200000, "supports_temperature": True}, + "anthropic.claude-v2": {"token_limit": 100000, "supports_temperature": True}, "anthropic.claude-instant-v1": { "token_limit": 100000, - "supports_temperature": True + "supports_temperature": True, }, "meta.llama3-8b-instruct-v1:0": { "token_limit": 8192, - "supports_temperature": True + "supports_temperature": True, }, "meta.llama3-70b-instruct-v1:0": { "token_limit": 8192, - "supports_temperature": True - }, - "meta.llama2-13b-chat-v1": { - "token_limit": 4096, - "supports_temperature": True - }, - "meta.llama2-70b-chat-v1": { - "token_limit": 4096, - "supports_temperature": True + "supports_temperature": True, }, + "meta.llama2-13b-chat-v1": {"token_limit": 4096, "supports_temperature": True}, + "meta.llama2-70b-chat-v1": {"token_limit": 4096, "supports_temperature": True}, "mistral.mistral-7b-instruct-v0:2": { "token_limit": 32768, - "supports_temperature": True + "supports_temperature": True, }, "mistral.mixtral-8x7b-instruct-v0:1": { "token_limit": 32768, - "supports_temperature": True + "supports_temperature": True, }, "mistral.mistral-large-2402-v1:0": { "token_limit": 32768, - "supports_temperature": True + "supports_temperature": True, }, "mistral.mistral-small-2402-v1:0": { "token_limit": 32768, - "supports_temperature": True + "supports_temperature": True, }, "amazon.titan-embed-text-v1": { "token_limit": 8000, - "supports_temperature": True 
+ "supports_temperature": True, }, "amazon.titan-embed-text-v2:0": { "token_limit": 8000, - "supports_temperature": True - }, - "cohere.embed-english-v3": { - "token_limit": 512, - "supports_temperature": True + "supports_temperature": True, }, + "cohere.embed-english-v3": {"token_limit": 512, "supports_temperature": True}, "cohere.embed-multilingual-v3": { "token_limit": 512, - "supports_temperature": True - } + "supports_temperature": True, + }, }, "mistralai": { - "mistral-large-latest": { - "token_limit": 128000, - "supports_temperature": True - }, - "open-mistral-nemo": { - "token_limit": 128000, - "supports_temperature": True - }, - "codestral-latest": { - "token_limit": 32000, - "supports_temperature": True - } + "mistral-large-latest": {"token_limit": 128000, "supports_temperature": True}, + "open-mistral-nemo": {"token_limit": 128000, "supports_temperature": True}, + "codestral-latest": {"token_limit": 32000, "supports_temperature": True}, }, "togetherai": { "Meta-Llama-3.1-70B-Instruct-Turbo": { "token_limit": 128000, - "supports_temperature": True + "supports_temperature": True, } - } + }, } diff --git a/ra_aid/webui/__init__.py b/ra_aid/webui/__init__.py index b829fab..5acfc1e 100644 --- a/ra_aid/webui/__init__.py +++ b/ra_aid/webui/__init__.py @@ -2,4 +2,4 @@ from .server import run_server -__all__ = ['run_server'] \ No newline at end of file +__all__ = ["run_server"] diff --git a/ra_aid/webui/server.py b/ra_aid/webui/server.py index 8fac321..fd2bd91 100644 --- a/ra_aid/webui/server.py +++ b/ra_aid/webui/server.py @@ -1,17 +1,17 @@ """Web interface server implementation for RA.Aid.""" import asyncio +import logging import shutil import sys -import logging from pathlib import Path -from typing import List, Dict, Any +from typing import Any, Dict, List import uvicorn -from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect -from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates # Configure logging @@ -19,8 +19,10 @@ logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for more info logger = logging.getLogger(__name__) # Verify ra-aid is available -if not shutil.which('ra-aid'): - logger.error("ra-aid command not found. Please ensure it's installed and in your PATH") +if not shutil.which("ra-aid"): + logger.error( + "ra-aid command not found. 
Please ensure it's installed and in your PATH" + ) sys.exit(1) app = FastAPI() @@ -45,6 +47,7 @@ logger.info(f"Using static directory: {STATIC_DIR}") # Setup templates templates = Jinja2Templates(directory=str(STATIC_DIR)) + class ConnectionManager: def __init__(self): self.active_connections: List[WebSocket] = [] @@ -73,77 +76,84 @@ class ConnectionManager: async def handle_error(self, websocket: WebSocket, error_message: str): try: - await websocket.send_json({ - "type": "chunk", - "chunk": { - "tools": { - "messages": [{ - "content": f"Error: {error_message}", - "status": "error" - }] - } + await websocket.send_json( + { + "type": "chunk", + "chunk": { + "tools": { + "messages": [ + { + "content": f"Error: {error_message}", + "status": "error", + } + ] + } + }, } - }) + ) except Exception as e: logger.error(f"Error sending error message: {e}") + manager = ConnectionManager() + @app.get("/", response_class=HTMLResponse) async def get_root(request: Request): """Serve the index.html file with port parameter.""" return templates.TemplateResponse( - "index.html", - {"request": request, "server_port": request.url.port or 8080} + "index.html", {"request": request, "server_port": request.url.port or 8080} ) + # Mount static files for js and other assets app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static") + @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): client_id = id(websocket) logger.info(f"New WebSocket connection attempt from client {client_id}") - + if not await manager.connect(websocket): logger.error(f"Failed to accept WebSocket connection for client {client_id}") return logger.info(f"WebSocket connection accepted for client {client_id}") - + try: # Send initial connection success message - await manager.send_message(websocket, { - "type": "chunk", - "chunk": { - "agent": { - "messages": [{ - "content": "Connected to RA.Aid server", - "status": "info" - }] - } - } - }) - + await manager.send_message( + websocket, + { + "type": "chunk", + "chunk": { + "agent": { + "messages": [ + {"content": "Connected to RA.Aid server", "status": "info"} + ] + } + }, + }, + ) + while True: try: message = await websocket.receive_json() logger.debug(f"Received message from client {client_id}: {message}") - + if message["type"] == "request": - await manager.send_message(websocket, { - "type": "stream_start" - }) - + await manager.send_message(websocket, {"type": "stream_start"}) + try: # Run ra-aid with the request cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"] logger.info(f"Executing command: {' '.join(cmd)}") - + process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) logger.info(f"Process started with PID: {process.pid}") @@ -152,27 +162,36 @@ async def websocket_endpoint(websocket: WebSocket): line = await stream.readline() if not line: break - + try: decoded_line = line.decode().strip() if decoded_line: - await manager.send_message(websocket, { - "type": "chunk", - "chunk": { - "tools" if is_error else "agent": { - "messages": [{ - "content": decoded_line, - "status": "error" if is_error else "info" - }] - } - } - }) + await manager.send_message( + websocket, + { + "type": "chunk", + "chunk": { + "tools" if is_error else "agent": { + "messages": [ + { + "content": decoded_line, + "status": "error" + if is_error + else "info", + } + ] + } + }, + }, + ) except Exception as e: logger.error(f"Error processing output: {e}") # Create tasks for 
reading stdout and stderr stdout_task = asyncio.create_task(read_stream(process.stdout)) - stderr_task = asyncio.create_task(read_stream(process.stderr, True)) + stderr_task = asyncio.create_task( + read_stream(process.stderr, True) + ) # Wait for both streams to complete await asyncio.gather(stdout_task, stderr_task) @@ -182,23 +201,22 @@ async def websocket_endpoint(websocket: WebSocket): if return_code != 0: await manager.handle_error( - websocket, - f"Process exited with code {return_code}" + websocket, f"Process exited with code {return_code}" ) - - await manager.send_message(websocket, { - "type": "stream_end", - "request": message["content"] - }) - + + await manager.send_message( + websocket, + {"type": "stream_end", "request": message["content"]}, + ) + except Exception as e: logger.error(f"Error executing ra-aid: {e}") await manager.handle_error(websocket, str(e)) - + except Exception as e: logger.error(f"Error processing message: {e}") await manager.handle_error(websocket, str(e)) - + except WebSocketDisconnect: logger.info(f"WebSocket client {client_id} disconnected") except Exception as e: @@ -207,6 +225,7 @@ async def websocket_endpoint(websocket: WebSocket): manager.disconnect(websocket) logger.info(f"WebSocket connection cleaned up for client {client_id}") + def run_server(host: str = "0.0.0.0", port: int = 8080): """Run the FastAPI server.""" logger.info(f"Starting server on {host}:{port}") @@ -216,5 +235,5 @@ def run_server(host: str = "0.0.0.0", port: int = 8080): port=port, log_level="debug", ws_max_size=16777216, # 16MB - timeout_keep_alive=0 # Disable keep-alive timeout - ) \ No newline at end of file + timeout_keep_alive=0, # Disable keep-alive timeout + ) diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index 549a862..5e935ed 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -248,26 +248,28 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_memory) assert agent == "react_agent" mock_react.assert_called_once_with(mock_model, []) + def test_get_model_token_limit_research(mock_memory): """Test get_model_token_limit with research provider and model.""" config = { "provider": "openai", "model": "gpt-4", "research_provider": "anthropic", - "research_model": "claude-2" + "research_model": "claude-2", } with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 150000} token_limit = get_model_token_limit(config, "research") assert token_limit == 150000 + def test_get_model_token_limit_planner(mock_memory): """Test get_model_token_limit with planner provider and model.""" config = { "provider": "openai", "model": "gpt-4", "planner_provider": "deepseek", - "planner_model": "dsm-1" + "planner_model": "dsm-1", } with patch("ra_aid.agent_utils.get_model_info") as mock_get_info: mock_get_info.return_value = {"max_input_tokens": 120000} diff --git a/webui/server.py b/webui/server.py index 51f5cef..03c3f88 100755 --- a/webui/server.py +++ b/webui/server.py @@ -1,22 +1,23 @@ #!/usr/bin/env python3 -import asyncio import argparse -import json -import sys +import asyncio import shutil +import sys from pathlib import Path -from typing import List, Dict, Any +from typing import List import uvicorn -from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect -from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect from 
fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates # Verify ra-aid is available -if not shutil.which('ra-aid'): - print("Error: ra-aid command not found. Please ensure it's installed and in your PATH") +if not shutil.which("ra-aid"): + print( + "Error: ra-aid command not found. Please ensure it's installed and in your PATH" + ) sys.exit(1) app = FastAPI() @@ -33,15 +34,16 @@ app.add_middleware( # Setup templates templates = Jinja2Templates(directory=Path(__file__).parent) + # Create a route for the root to serve index.html with port parameter @app.get("/", response_class=HTMLResponse) async def get_root(request: Request): """Serve the index.html file with port parameter.""" return templates.TemplateResponse( - "index.html", - {"request": request, "server_port": request.url.port or 8080} + "index.html", {"request": request, "server_port": request.url.port or 8080} ) + # Mount static files for js and other assets app.mount("/static", StaticFiles(directory=Path(__file__).parent), name="static") @@ -50,36 +52,35 @@ app.mount("/static", StaticFiles(directory=Path(__file__).parent), name="static" # Store active WebSocket connections active_connections: List[WebSocket] = [] + @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): print(f"New WebSocket connection from {websocket.client}") await websocket.accept() print("WebSocket connection accepted") active_connections.append(websocket) - + try: while True: print("Waiting for message...") message = await websocket.receive_json() print(f"Received message: {message}") - + if message["type"] == "request": print(f"Processing request: {message['content']}") # Notify client that streaming is starting - await websocket.send_json({ - "type": "stream_start" - }) - + await websocket.send_json({"type": "stream_start"}) + try: # Run ra-aid with the request using -m flag and cowboy mode cmd = ["ra-aid", "-m", message["content"], "--cowboy-mode"] print(f"Executing command: {' '.join(cmd)}") - + # Create subprocess process = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) print(f"Process started with PID: {process.pid}") @@ -92,22 +93,30 @@ async def websocket_endpoint(websocket: WebSocket): if not line: print(f"End of {stream_type} stream") break - + try: decoded_line = line.decode().strip() print(f"{stream_type} line: {decoded_line}") if decoded_line: - await websocket.send_json({ - "type": "chunk", - "chunk": { - "tools" if is_error else "agent": { - "messages": [{ - "content": decoded_line, - "status": "error" if is_error else "info" - }] - } + await websocket.send_json( + { + "type": "chunk", + "chunk": { + "tools" if is_error else "agent": { + "messages": [ + { + "content": decoded_line, + "status": ( + "error" + if is_error + else "info" + ), + } + ] + } + }, } - }) + ) except Exception as e: print(f"Error sending output: {e}") @@ -122,74 +131,80 @@ async def websocket_endpoint(websocket: WebSocket): return_code = await process.wait() if return_code != 0: - await websocket.send_json({ - "type": "chunk", - "chunk": { - "tools": { - "messages": [{ - "content": f"Process exited with code {return_code}", - "status": "error" - }] - } + await websocket.send_json( + { + "type": "chunk", + "chunk": { + "tools": { + "messages": [ + { + "content": f"Process exited with code {return_code}", + "status": "error", + } + ] + } 
+ }, } - }) - + ) + # Notify client that streaming is complete - await websocket.send_json({ - "type": "stream_end", - "request": message["content"] - }) - + await websocket.send_json( + {"type": "stream_end", "request": message["content"]} + ) + except Exception as e: error_msg = f"Error executing ra-aid: {str(e)}" print(error_msg) - await websocket.send_json({ - "type": "chunk", - "chunk": { - "tools": { - "messages": [{ - "content": error_msg, - "status": "error" - }] - } + await websocket.send_json( + { + "type": "chunk", + "chunk": { + "tools": { + "messages": [ + {"content": error_msg, "status": "error"} + ] + } + }, } - }) - + ) + except WebSocketDisconnect: - print(f"WebSocket client disconnected") + print("WebSocket client disconnected") active_connections.remove(websocket) except Exception as e: print(f"WebSocket error: {e}") try: - await websocket.send_json({ - "type": "error", - "error": str(e) - }) - except: + await websocket.send_json({"type": "error", "error": str(e)}) + except Exception: pass finally: if websocket in active_connections: active_connections.remove(websocket) print("WebSocket connection cleaned up") + @app.get("/config") async def get_config(request: Request): """Return server configuration including host and port.""" - return { - "host": request.client.host, - "port": request.scope.get("server")[1] - } + return {"host": request.client.host, "port": request.scope.get("server")[1]} + def run_server(host: str = "0.0.0.0", port: int = 8080): """Run the FastAPI server.""" uvicorn.run(app, host=host, port=port) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="RA.Aid Web Interface Server") - parser.add_argument("--port", type=int, default=8080, - help="Port to listen on (default: 8080)") - parser.add_argument("--host", type=str, default="0.0.0.0", - help="Host to listen on (default: 0.0.0.0)") - + parser.add_argument( + "--port", type=int, default=8080, help="Port to listen on (default: 8080)" + ) + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to listen on (default: 0.0.0.0)", + ) + args = parser.parse_args() - run_server(host=args.host, port=args.port) \ No newline at end of file + run_server(host=args.host, port=args.port)
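
Usage sketch (illustrative, not part of the patch): a minimal example of how the new agent-type-aware get_model_token_limit is intended to be called, assuming the ra_aid package from this repository is installed and using the same per-agent config keys the tests above exercise (research_provider/research_model, planner_provider/planner_model). The concrete model names below are placeholders chosen for illustration, not values taken from the patch.

    from ra_aid.agent_utils import get_model_token_limit

    config = {
        "provider": "openai",
        "model": "gpt-4",
        # Optional per-agent overrides; when absent, provider/model are used instead.
        "planner_provider": "anthropic",
        "planner_model": "claude-3-5-sonnet-20241022",
    }

    # "default" and "research" behave the same way; an unknown provider/model
    # yields None, and callers such as create_agent() fall back to DEFAULT_TOKEN_LIMIT.
    limit = get_model_token_limit(config, "planner")
    print(limit)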