feat: add session and trajectory models to track application state and events
- Introduce a new `Session` model to store information about each program run, including command line arguments and environment details. - Implement a `Trajectory` model to log significant events and errors during execution, enhancing debugging and monitoring capabilities. - Update various repository classes to support session and trajectory management, allowing for better tracking of user interactions and system behavior. - Modify existing functions to record relevant events in the trajectory, ensuring comprehensive logging of application activities. - Enhance error handling by logging errors to the trajectory, providing insights into failures and system performance. feat(vsc): add initial setup for VS Code extension "ra-aid" with essential files and configurations chore(vsc): create tasks.json for managing build and watch tasks in VS Code chore(vsc): add .vscodeignore to exclude unnecessary files from the extension package docs(vsc): create CHANGELOG.md to document changes and updates for the extension docs(vsc): add README.md with instructions and information about the extension feat(vsc): include esbuild.js for building and bundling the extension chore(vsc): add eslint.config.mjs for TypeScript linting configuration chore(vsc): create package.json with dependencies and scripts for the extension feat(vsc): implement extension logic in src/extension.ts with webview support test(vsc): add initial test suite in extension.test.ts for extension functionality chore(vsc): create tsconfig.json for TypeScript compiler options docs(vsc): add vsc-extension-quickstart.md for guidance on extension development
This commit is contained in:
commit
77a256317a
|
|
@ -14,3 +14,5 @@ __pycache__/
|
||||||
.envrc
|
.envrc
|
||||||
appmap.log
|
appmap.log
|
||||||
*.swp
|
*.swp
|
||||||
|
/vsc/node_modules
|
||||||
|
/vsc/dist
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,9 @@ from ra_aid.database.repositories.trajectory_repository import (
|
||||||
TrajectoryRepositoryManager,
|
TrajectoryRepositoryManager,
|
||||||
get_trajectory_repository,
|
get_trajectory_repository,
|
||||||
)
|
)
|
||||||
|
from ra_aid.database.repositories.session_repository import (
|
||||||
|
SessionRepositoryManager, get_session_repository
|
||||||
|
)
|
||||||
from ra_aid.database.repositories.related_files_repository import (
|
from ra_aid.database.repositories.related_files_repository import (
|
||||||
RelatedFilesRepositoryManager,
|
RelatedFilesRepositoryManager,
|
||||||
)
|
)
|
||||||
|
|
@ -298,6 +301,11 @@ Examples:
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Display model thinking content extracted from think tags when supported by the model",
|
help="Display model thinking content extracted from think tags when supported by the model",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show-cost",
|
||||||
|
action="store_true",
|
||||||
|
help="Display cost information as the agent works",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reasoning-assistance",
|
"--reasoning-assistance",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
@ -538,18 +546,18 @@ def main():
|
||||||
env_discovery.discover()
|
env_discovery.discover()
|
||||||
env_data = env_discovery.format_markdown()
|
env_data = env_discovery.format_markdown()
|
||||||
|
|
||||||
with (
|
with SessionRepositoryManager(db) as session_repo, \
|
||||||
KeyFactRepositoryManager(db) as key_fact_repo,
|
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||||
KeySnippetRepositoryManager(db) as key_snippet_repo,
|
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||||
HumanInputRepositoryManager(db) as human_input_repo,
|
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||||
ResearchNoteRepositoryManager(db) as research_note_repo,
|
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||||
RelatedFilesRepositoryManager() as related_files_repo,
|
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||||
TrajectoryRepositoryManager(db) as trajectory_repo,
|
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||||
WorkLogRepositoryManager() as work_log_repo,
|
WorkLogRepositoryManager() as work_log_repo, \
|
||||||
ConfigRepositoryManager(config) as config_repo,
|
ConfigRepositoryManager(config) as config_repo, \
|
||||||
EnvInvManager(env_data) as env_inv,
|
EnvInvManager(env_data) as env_inv:
|
||||||
):
|
|
||||||
# This initializes all repositories and makes them available via their respective get methods
|
# This initializes all repositories and makes them available via their respective get methods
|
||||||
|
logger.debug("Initialized SessionRepository")
|
||||||
logger.debug("Initialized KeyFactRepository")
|
logger.debug("Initialized KeyFactRepository")
|
||||||
logger.debug("Initialized KeySnippetRepository")
|
logger.debug("Initialized KeySnippetRepository")
|
||||||
logger.debug("Initialized HumanInputRepository")
|
logger.debug("Initialized HumanInputRepository")
|
||||||
|
|
@ -560,6 +568,10 @@ def main():
|
||||||
logger.debug("Initialized ConfigRepository")
|
logger.debug("Initialized ConfigRepository")
|
||||||
logger.debug("Initialized Environment Inventory")
|
logger.debug("Initialized Environment Inventory")
|
||||||
|
|
||||||
|
# Create a new session for this program run
|
||||||
|
logger.debug("Initializing new session")
|
||||||
|
session_repo.create_session()
|
||||||
|
|
||||||
# Check dependencies before proceeding
|
# Check dependencies before proceeding
|
||||||
check_dependencies()
|
check_dependencies()
|
||||||
|
|
||||||
|
|
@ -611,6 +623,7 @@ def main():
|
||||||
)
|
)
|
||||||
config_repo.set("web_research_enabled", web_research_enabled)
|
config_repo.set("web_research_enabled", web_research_enabled)
|
||||||
config_repo.set("show_thoughts", args.show_thoughts)
|
config_repo.set("show_thoughts", args.show_thoughts)
|
||||||
|
config_repo.set("show_cost", args.show_cost)
|
||||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
config_repo.set(
|
config_repo.set(
|
||||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||||
|
|
@ -636,11 +649,41 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.research_only:
|
if args.research_only:
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
error_message = "Chat mode cannot be used with --research-only"
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"display_title": "Error",
|
||||||
|
"error_message": error_message,
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
except Exception as traj_error:
|
||||||
|
# Swallow exception to avoid recursion
|
||||||
|
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||||
|
pass
|
||||||
print_error("Chat mode cannot be used with --research-only")
|
print_error("Chat mode cannot be used with --research-only")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print_stage_header("Chat Mode")
|
print_stage_header("Chat Mode")
|
||||||
|
|
||||||
|
# Record stage transition in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"stage": "chat_mode",
|
||||||
|
"display_title": "Chat Mode",
|
||||||
|
},
|
||||||
|
record_type="stage_transition",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
# Get project info
|
# Get project info
|
||||||
try:
|
try:
|
||||||
project_info = get_project_info(".", file_limit=2000)
|
project_info = get_project_info(".", file_limit=2000)
|
||||||
|
|
@ -690,12 +733,9 @@ def main():
|
||||||
config_repo.set("expert_model", args.expert_model)
|
config_repo.set("expert_model", args.expert_model)
|
||||||
config_repo.set("temperature", args.temperature)
|
config_repo.set("temperature", args.temperature)
|
||||||
config_repo.set("show_thoughts", args.show_thoughts)
|
config_repo.set("show_thoughts", args.show_thoughts)
|
||||||
config_repo.set(
|
config_repo.set("show_cost", args.show_cost)
|
||||||
"force_reasoning_assistance", args.reasoning_assistance
|
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||||
)
|
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
||||||
config_repo.set(
|
|
||||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set modification tools based on use_aider flag
|
# Set modification tools based on use_aider flag
|
||||||
set_modification_tools(args.use_aider)
|
set_modification_tools(args.use_aider)
|
||||||
|
|
@ -737,6 +777,24 @@ def main():
|
||||||
|
|
||||||
# Validate message is provided
|
# Validate message is provided
|
||||||
if not args.message:
|
if not args.message:
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
error_message = "--message is required"
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"display_title": "Error",
|
||||||
|
"error_message": error_message,
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
except Exception as traj_error:
|
||||||
|
# Swallow exception to avoid recursion
|
||||||
|
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||||
|
pass
|
||||||
print_error("--message is required")
|
print_error("--message is required")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
@ -806,6 +864,18 @@ def main():
|
||||||
# Run research stage
|
# Run research stage
|
||||||
print_stage_header("Research Stage")
|
print_stage_header("Research Stage")
|
||||||
|
|
||||||
|
# Record stage transition in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"stage": "research_stage",
|
||||||
|
"display_title": "Research Stage",
|
||||||
|
},
|
||||||
|
record_type="stage_transition",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize research model with potential overrides
|
# Initialize research model with potential overrides
|
||||||
research_provider = args.research_provider or args.provider
|
research_provider = args.research_provider or args.provider
|
||||||
research_model_name = args.research_model or args.model
|
research_model_name = args.research_model or args.model
|
||||||
|
|
|
||||||
|
|
@ -462,6 +462,38 @@ class CiaynAgent:
|
||||||
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
||||||
tool_name = self.extract_tool_name(code)
|
tool_name = self.extract_tool_name(code)
|
||||||
logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")
|
logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": f"Tool execution failed for `{tool_name}`:\nError: {str(e)}",
|
||||||
|
"display_title": "Tool Error",
|
||||||
|
"code": code,
|
||||||
|
"tool_name": tool_name
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type="ToolExecutionError",
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_parameters={"code": code}
|
||||||
|
)
|
||||||
|
except Exception as trajectory_error:
|
||||||
|
# Just log and continue if there's an error in trajectory recording
|
||||||
|
logger.error(f"Error recording trajectory for tool error display: {trajectory_error}")
|
||||||
|
|
||||||
print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
|
print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
|
||||||
raise ToolExecutionError(
|
raise ToolExecutionError(
|
||||||
error_msg, base_message=msg, tool_name=tool_name
|
error_msg, base_message=msg, tool_name=tool_name
|
||||||
|
|
@ -495,6 +527,36 @@ class CiaynAgent:
|
||||||
if not fallback_response:
|
if not fallback_response:
|
||||||
self.chat_history.append(err_msg)
|
self.chat_history.append(err_msg)
|
||||||
logger.info(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
|
logger.info(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
|
||||||
|
"display_title": "Fallback Failed",
|
||||||
|
"tool_name": e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type="FallbackFailedError",
|
||||||
|
tool_name=e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||||
|
)
|
||||||
|
except Exception as trajectory_error:
|
||||||
|
# Just log and continue if there's an error in trajectory recording
|
||||||
|
logger.error(f"Error recording trajectory for fallback failed warning: {trajectory_error}")
|
||||||
|
|
||||||
print_warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", title="Fallback Failed")
|
print_warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", title="Fallback Failed")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -595,6 +657,35 @@ class CiaynAgent:
|
||||||
matches = re.findall(pattern, response, re.DOTALL)
|
matches = re.findall(pattern, response, re.DOTALL)
|
||||||
if len(matches) == 0:
|
if len(matches) == 0:
|
||||||
logger.info("Failed to extract a valid tool call from the model's response.")
|
logger.info("Failed to extract a valid tool call from the model's response.")
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": "Failed to extract a valid tool call from the model's response.",
|
||||||
|
"display_title": "Extraction Failed",
|
||||||
|
"code": code
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message="Failed to extract a valid tool call from the model's response.",
|
||||||
|
error_type="ExtractionError"
|
||||||
|
)
|
||||||
|
except Exception as trajectory_error:
|
||||||
|
# Just log and continue if there's an error in trajectory recording
|
||||||
|
logger.error(f"Error recording trajectory for extraction error display: {trajectory_error}")
|
||||||
|
|
||||||
print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
|
print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
|
||||||
raise ToolExecutionError("Failed to extract tool call")
|
raise ToolExecutionError("Failed to extract tool call")
|
||||||
ma = matches[0][0].strip()
|
ma = matches[0][0].strip()
|
||||||
|
|
@ -647,6 +738,36 @@ class CiaynAgent:
|
||||||
|
|
||||||
warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
|
warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
|
||||||
logger.info(warning_message)
|
logger.info(warning_message)
|
||||||
|
|
||||||
|
# Record warning in trajectory
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db_connection
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||||
|
human_input_repo = HumanInputRepository(get_db_connection())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"warning_message": warning_message,
|
||||||
|
"display_title": "Empty Response",
|
||||||
|
"attempt": empty_response_count,
|
||||||
|
"max_attempts": max_empty_responses
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=warning_message,
|
||||||
|
error_type="EmptyResponseWarning"
|
||||||
|
)
|
||||||
|
except Exception as trajectory_error:
|
||||||
|
# Just log and continue if there's an error in trajectory recording
|
||||||
|
logger.error(f"Error recording trajectory for empty response warning: {trajectory_error}")
|
||||||
|
|
||||||
print_warning(warning_message, title="Empty Response")
|
print_warning(warning_message, title="Empty Response")
|
||||||
|
|
||||||
if empty_response_count >= max_empty_responses:
|
if empty_response_count >= max_empty_responses:
|
||||||
|
|
@ -658,6 +779,36 @@ class CiaynAgent:
|
||||||
|
|
||||||
error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
|
error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db_connection
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||||
|
human_input_repo = HumanInputRepository(get_db_connection())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Agent Crashed",
|
||||||
|
"crash_reason": crash_message,
|
||||||
|
"attempts": empty_response_count
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message,
|
||||||
|
error_type="AgentCrashError"
|
||||||
|
)
|
||||||
|
except Exception as trajectory_error:
|
||||||
|
# Just log and continue if there's an error in trajectory recording
|
||||||
|
logger.error(f"Error recording trajectory for agent crash: {trajectory_error}")
|
||||||
|
|
||||||
print_error(error_message)
|
print_error(error_message)
|
||||||
|
|
||||||
yield self._create_error_chunk(crash_message)
|
yield self._create_error_chunk(crash_message)
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,10 @@ from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||||
|
from ra_aid.database.repositories.human_input_repository import (
|
||||||
|
get_human_input_repository,
|
||||||
|
)
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
||||||
|
|
||||||
|
|
@ -284,9 +288,23 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
|
|
||||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||||
delay = base_delay * (2**attempt)
|
delay = base_delay * (2**attempt)
|
||||||
print_error(
|
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
while time.monotonic() - start < delay:
|
while time.monotonic() - start < delay:
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from ra_aid import agent_utils
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||||
from ra_aid.tools.memory import log_work_event
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
@ -82,6 +83,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
||||||
if deleted_facts:
|
if deleted_facts:
|
||||||
deleted_msg = "Successfully deleted facts:\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in deleted_facts])
|
deleted_msg = "Successfully deleted facts:\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in deleted_facts])
|
||||||
result_parts.append(deleted_msg)
|
result_parts.append(deleted_msg)
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"deleted_facts": deleted_facts,
|
||||||
|
"display_title": "Facts Deleted",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
|
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
|
||||||
)
|
)
|
||||||
|
|
@ -89,6 +106,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
||||||
if protected_facts:
|
if protected_facts:
|
||||||
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
|
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
|
||||||
result_parts.append(protected_msg)
|
result_parts.append(protected_msg)
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_facts": protected_facts,
|
||||||
|
"display_title": "Facts Protected",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
|
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
|
||||||
)
|
)
|
||||||
|
|
@ -120,10 +153,44 @@ def run_key_facts_gc_agent() -> None:
|
||||||
fact_count = len(facts)
|
fact_count = len(facts)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||||
|
# Record GC error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error": str(e),
|
||||||
|
"display_title": "GC Error",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent",
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type="Repository Error"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||||
return # Exit the function if we can't access the repository
|
return # Exit the function if we can't access the repository
|
||||||
|
|
||||||
# Display status panel with fact count included
|
# Display status panel with fact count included
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"fact_count": fact_count,
|
||||||
|
"display_title": "Garbage Collection",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
||||||
|
|
||||||
# Only run the agent if we actually have facts to clean
|
# Only run the agent if we actually have facts to clean
|
||||||
|
|
@ -185,6 +252,24 @@ def run_key_facts_gc_agent() -> None:
|
||||||
# Show info panel with updated count and protected facts count
|
# Show info panel with updated count and protected facts count
|
||||||
protected_count = len(protected_facts)
|
protected_count = len(protected_facts)
|
||||||
if protected_count > 0:
|
if protected_count > 0:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": fact_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": protected_count,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned key facts: {fact_count} → {updated_count}\nProtected facts (associated with current request): {protected_count}",
|
f"Cleaned key facts: {fact_count} → {updated_count}\nProtected facts (associated with current request): {protected_count}",
|
||||||
|
|
@ -192,6 +277,24 @@ def run_key_facts_gc_agent() -> None:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": fact_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": 0,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned key facts: {fact_count} → {updated_count}",
|
f"Cleaned key facts: {fact_count} → {updated_count}",
|
||||||
|
|
@ -199,6 +302,40 @@ def run_key_facts_gc_agent() -> None:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_count": len(protected_facts),
|
||||||
|
"message": "All facts are protected",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"fact_count": 0,
|
||||||
|
"message": "No key facts to clean",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_facts_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))
|
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))
|
||||||
|
|
@ -18,6 +18,7 @@ from ra_aid import agent_utils
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||||
from ra_aid.tools.memory import log_work_event
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
@ -65,6 +66,23 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||||
success = get_key_snippet_repository().delete(snippet_id)
|
success = get_key_snippet_repository().delete(snippet_id)
|
||||||
if success:
|
if success:
|
||||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"deleted_snippet_id": snippet_id,
|
||||||
|
"filepath": filepath,
|
||||||
|
"display_title": "Snippet Deleted",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||||
|
|
@ -86,6 +104,22 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
||||||
if protected_snippets:
|
if protected_snippets:
|
||||||
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
|
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
|
||||||
result_parts.append(protected_msg)
|
result_parts.append(protected_msg)
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_snippets": protected_snippets,
|
||||||
|
"display_title": "Snippets Protected",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
|
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
|
||||||
)
|
)
|
||||||
|
|
@ -116,6 +150,21 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
snippet_count = len(snippets)
|
snippet_count = len(snippets)
|
||||||
|
|
||||||
# Display status panel with snippet count included
|
# Display status panel with snippet count included
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"snippet_count": snippet_count,
|
||||||
|
"display_title": "Garbage Collection",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key snippets: {snippet_count}", title="🗑 Garbage Collection"))
|
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key snippets: {snippet_count}", title="🗑 Garbage Collection"))
|
||||||
|
|
||||||
# Only run the agent if we actually have snippets to clean
|
# Only run the agent if we actually have snippets to clean
|
||||||
|
|
@ -185,6 +234,24 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
# Show info panel with updated count and protected snippets count
|
# Show info panel with updated count and protected snippets count
|
||||||
protected_count = len(protected_snippets)
|
protected_count = len(protected_snippets)
|
||||||
if protected_count > 0:
|
if protected_count > 0:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": snippet_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": protected_count,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned key snippets: {snippet_count} → {updated_count}\nProtected snippets (associated with current request): {protected_count}",
|
f"Cleaned key snippets: {snippet_count} → {updated_count}\nProtected snippets (associated with current request): {protected_count}",
|
||||||
|
|
@ -192,6 +259,24 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": snippet_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": 0,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
||||||
|
|
@ -199,6 +284,40 @@ def run_key_snippets_gc_agent() -> None:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_count": len(protected_snippets),
|
||||||
|
"message": "All snippets are protected",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"snippet_count": 0,
|
||||||
|
"message": "No key snippets to clean",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="key_snippets_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))
|
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))
|
||||||
|
|
@ -24,6 +24,8 @@ from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.env_inv_context import get_env_inv
|
from ra_aid.env_inv_context import get_env_inv
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
from ra_aid.llm import initialize_expert_llm
|
from ra_aid.llm import initialize_expert_llm
|
||||||
|
|
@ -156,6 +158,18 @@ def run_planning_agent(
|
||||||
# Display the planning stage header before any reasoning assistance
|
# Display the planning stage header before any reasoning assistance
|
||||||
print_stage_header("Planning Stage")
|
print_stage_header("Planning Stage")
|
||||||
|
|
||||||
|
# Record stage transition in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"stage": "planning_stage",
|
||||||
|
"display_title": "Planning Stage",
|
||||||
|
},
|
||||||
|
record_type="stage_transition",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize expert guidance section
|
# Initialize expert guidance section
|
||||||
expert_guidance = ""
|
expert_guidance = ""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.llm import initialize_llm
|
from ra_aid.llm import initialize_llm
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||||
from ra_aid.tools.memory import log_work_event
|
from ra_aid.tools.memory import log_work_event
|
||||||
|
|
@ -84,6 +85,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
||||||
if deleted_notes:
|
if deleted_notes:
|
||||||
deleted_msg = "Successfully deleted research notes:\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in deleted_notes])
|
deleted_msg = "Successfully deleted research notes:\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in deleted_notes])
|
||||||
result_parts.append(deleted_msg)
|
result_parts.append(deleted_msg)
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"deleted_notes": deleted_notes,
|
||||||
|
"display_title": "Research Notes Deleted",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(deleted_msg), title="Research Notes Deleted", border_style="green")
|
Panel(Markdown(deleted_msg), title="Research Notes Deleted", border_style="green")
|
||||||
)
|
)
|
||||||
|
|
@ -91,6 +108,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
||||||
if protected_notes:
|
if protected_notes:
|
||||||
protected_msg = "Protected research notes (associated with current request):\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in protected_notes])
|
protected_msg = "Protected research notes (associated with current request):\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in protected_notes])
|
||||||
result_parts.append(protected_msg)
|
result_parts.append(protected_msg)
|
||||||
|
# Record GC operation in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_notes": protected_notes,
|
||||||
|
"display_title": "Research Notes Protected",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(protected_msg), title="Research Notes Protected", border_style="blue")
|
Panel(Markdown(protected_msg), title="Research Notes Protected", border_style="blue")
|
||||||
)
|
)
|
||||||
|
|
@ -125,10 +158,44 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
|
||||||
note_count = len(notes)
|
note_count = len(notes)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logger.error(f"Failed to access research note repository: {str(e)}")
|
logger.error(f"Failed to access research note repository: {str(e)}")
|
||||||
|
# Record GC error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error": str(e),
|
||||||
|
"display_title": "GC Error",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent",
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type="Repository Error"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||||
return # Exit the function if we can't access the repository
|
return # Exit the function if we can't access the repository
|
||||||
|
|
||||||
# Display status panel with note count included
|
# Display status panel with note count included
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"note_count": note_count,
|
||||||
|
"display_title": "Garbage Collection",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of research notes: {note_count}", title="🗑 Garbage Collection"))
|
console.print(Panel(f"Gathering my thoughts...\nCurrent number of research notes: {note_count}", title="🗑 Garbage Collection"))
|
||||||
|
|
||||||
# Only run the agent if we actually have notes to clean and we're over the threshold
|
# Only run the agent if we actually have notes to clean and we're over the threshold
|
||||||
|
|
@ -235,6 +302,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
||||||
# Show info panel with updated count and protected notes count
|
# Show info panel with updated count and protected notes count
|
||||||
protected_count = len(protected_notes)
|
protected_count = len(protected_notes)
|
||||||
if protected_count > 0:
|
if protected_count > 0:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": note_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": protected_count,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned research notes: {note_count} → {updated_count}\nProtected notes (associated with current request): {protected_count}",
|
f"Cleaned research notes: {note_count} → {updated_count}\nProtected notes (associated with current request): {protected_count}",
|
||||||
|
|
@ -242,6 +327,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC completion in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"original_count": note_count,
|
||||||
|
"updated_count": updated_count,
|
||||||
|
"protected_count": 0,
|
||||||
|
"display_title": "GC Complete",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cleaned research notes: {note_count} → {updated_count}",
|
f"Cleaned research notes: {note_count} → {updated_count}",
|
||||||
|
|
@ -249,6 +352,41 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"protected_count": len(protected_notes),
|
||||||
|
"message": "All research notes are protected",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"All {len(protected_notes)} research notes are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
console.print(Panel(f"All {len(protected_notes)} research notes are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||||
else:
|
else:
|
||||||
|
# Record GC info in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"note_count": note_count,
|
||||||
|
"threshold": threshold,
|
||||||
|
"message": "Below threshold - no cleanup needed",
|
||||||
|
"display_title": "GC Info",
|
||||||
|
},
|
||||||
|
record_type="gc_operation",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
tool_name="research_notes_gc_agent"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Continue if trajectory recording fails
|
||||||
|
|
||||||
console.print(Panel(f"Research notes count ({note_count}) is below threshold ({threshold}). No cleanup needed.", title="🗑 GC Info"))
|
console.print(Panel(f"Research notes count ({note_count}) is below threshold ({threshold}). No cleanup needed.", title="🗑 GC Info"))
|
||||||
|
|
@ -7,6 +7,7 @@ FALLBACK_TOOL_MODEL_LIMIT = 5
|
||||||
RETRY_FALLBACK_COUNT = 3
|
RETRY_FALLBACK_COUNT = 3
|
||||||
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
|
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
|
||||||
DEFAULT_MODEL="claude-3-7-sonnet-20250219"
|
DEFAULT_MODEL="claude-3-7-sonnet-20250219"
|
||||||
|
DEFAULT_SHOW_COST = False
|
||||||
|
|
||||||
|
|
||||||
VALID_PROVIDERS = [
|
VALID_PROVIDERS = [
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,10 @@
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,14 +6,18 @@ from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||||
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.config import DEFAULT_SHOW_COST
|
||||||
|
|
||||||
# Import shared console instance
|
# Import shared console instance
|
||||||
from .formatting import console
|
from .formatting import console
|
||||||
|
|
||||||
|
|
||||||
def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]:
|
def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]:
|
||||||
"""Generate a subtitle with cost information if a callback is provided."""
|
"""Generate a subtitle with cost information if a callback is provided and show_cost is enabled."""
|
||||||
if cost_cb:
|
# Only show cost information if both cost_cb is provided AND show_cost is True
|
||||||
|
show_cost = get_config_repository().get("show_cost", DEFAULT_SHOW_COST)
|
||||||
|
if cost_cb and show_cost:
|
||||||
return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}"
|
return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,8 @@ def initialize_database():
|
||||||
# to avoid circular imports
|
# to avoid circular imports
|
||||||
# Note: This import needs to be here, not at the top level
|
# Note: This import needs to be here, not at the top level
|
||||||
try:
|
try:
|
||||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
|
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session
|
||||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
|
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True)
|
||||||
logger.debug("Ensured database tables exist")
|
logger.debug("Ensured database tables exist")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating tables: {str(e)}")
|
logger.error(f"Error creating tables: {str(e)}")
|
||||||
|
|
@ -99,6 +99,25 @@ class BaseModel(peewee.Model):
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class Session(BaseModel):
|
||||||
|
"""
|
||||||
|
Model representing a session stored in the database.
|
||||||
|
|
||||||
|
Sessions track information about each program run, providing a way to group
|
||||||
|
related records like human inputs, trajectories, and key facts.
|
||||||
|
|
||||||
|
Each session record captures details about when the program was started,
|
||||||
|
what command line arguments were used, and environment information.
|
||||||
|
"""
|
||||||
|
start_time = peewee.DateTimeField(default=datetime.datetime.now)
|
||||||
|
command_line = peewee.TextField(null=True)
|
||||||
|
program_version = peewee.TextField(null=True)
|
||||||
|
machine_info = peewee.TextField(null=True) # JSON-encoded machine information
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "session"
|
||||||
|
|
||||||
|
|
||||||
class HumanInput(BaseModel):
|
class HumanInput(BaseModel):
|
||||||
"""
|
"""
|
||||||
Model representing human input stored in the database.
|
Model representing human input stored in the database.
|
||||||
|
|
@ -109,6 +128,7 @@ class HumanInput(BaseModel):
|
||||||
"""
|
"""
|
||||||
content = peewee.TextField()
|
content = peewee.TextField()
|
||||||
source = peewee.TextField() # 'cli', 'chat', or 'hil'
|
source = peewee.TextField() # 'cli', 'chat', or 'hil'
|
||||||
|
session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True)
|
||||||
# created_at and updated_at are inherited from BaseModel
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
@ -124,6 +144,7 @@ class KeyFact(BaseModel):
|
||||||
"""
|
"""
|
||||||
content = peewee.TextField()
|
content = peewee.TextField()
|
||||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
|
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
|
||||||
|
session = peewee.ForeignKeyField(Session, backref='key_facts', null=True)
|
||||||
# created_at and updated_at are inherited from BaseModel
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
@ -143,6 +164,7 @@ class KeySnippet(BaseModel):
|
||||||
snippet = peewee.TextField()
|
snippet = peewee.TextField()
|
||||||
description = peewee.TextField(null=True)
|
description = peewee.TextField(null=True)
|
||||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
|
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
|
||||||
|
session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True)
|
||||||
# created_at and updated_at are inherited from BaseModel
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
@ -159,6 +181,7 @@ class ResearchNote(BaseModel):
|
||||||
"""
|
"""
|
||||||
content = peewee.TextField()
|
content = peewee.TextField()
|
||||||
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
|
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
|
||||||
|
session = peewee.ForeignKeyField(Session, backref='research_notes', null=True)
|
||||||
# created_at and updated_at are inherited from BaseModel
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
@ -182,17 +205,18 @@ class Trajectory(BaseModel):
|
||||||
- Error information (when a tool execution fails)
|
- Error information (when a tool execution fails)
|
||||||
"""
|
"""
|
||||||
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
||||||
tool_name = peewee.TextField()
|
tool_name = peewee.TextField(null=True)
|
||||||
tool_parameters = peewee.TextField() # JSON-encoded parameters
|
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_result = peewee.TextField() # JSON-encoded result
|
tool_result = peewee.TextField(null=True) # JSON-encoded result
|
||||||
step_data = peewee.TextField() # JSON-encoded UI rendering data
|
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
|
||||||
record_type = peewee.TextField() # Type of trajectory record
|
record_type = peewee.TextField(null=True) # Type of trajectory record
|
||||||
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
||||||
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||||
error_message = peewee.TextField(null=True) # The error message
|
error_message = peewee.TextField(null=True) # The error message
|
||||||
error_type = peewee.TextField(null=True) # The type/class of the error
|
error_type = peewee.TextField(null=True) # The type/class of the error
|
||||||
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
|
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
|
||||||
|
session = peewee.ForeignKeyField(Session, backref='trajectories', null=True)
|
||||||
# created_at and updated_at are inherited from BaseModel
|
# created_at and updated_at are inherited from BaseModel
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ class ConfigRepository:
|
||||||
FALLBACK_TOOL_MODEL_LIMIT,
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
RETRY_FALLBACK_COUNT,
|
RETRY_FALLBACK_COUNT,
|
||||||
DEFAULT_TEST_CMD_TIMEOUT,
|
DEFAULT_TEST_CMD_TIMEOUT,
|
||||||
|
DEFAULT_SHOW_COST,
|
||||||
VALID_PROVIDERS,
|
VALID_PROVIDERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -42,6 +43,7 @@ class ConfigRepository:
|
||||||
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
|
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
"retry_fallback_count": RETRY_FALLBACK_COUNT,
|
"retry_fallback_count": RETRY_FALLBACK_COUNT,
|
||||||
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
|
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
|
||||||
|
"show_cost": DEFAULT_SHOW_COST,
|
||||||
"valid_providers": VALID_PROVIDERS,
|
"valid_providers": VALID_PROVIDERS,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,243 @@
|
||||||
|
"""
|
||||||
|
Session repository implementation for database access.
|
||||||
|
|
||||||
|
This module provides a repository implementation for the Session model,
|
||||||
|
following the repository pattern for data access abstraction. It handles
|
||||||
|
operations for storing and retrieving application session information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import contextvars
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import peewee
|
||||||
|
|
||||||
|
from ra_aid.database.models import Session
|
||||||
|
from ra_aid.__version__ import __version__
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Create contextvar to hold the SessionRepository instance
|
||||||
|
session_repo_var = contextvars.ContextVar("session_repo", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRepositoryManager:
|
||||||
|
"""
|
||||||
|
Context manager for SessionRepository.
|
||||||
|
|
||||||
|
This class provides a context manager interface for SessionRepository,
|
||||||
|
using the contextvars approach for thread safety.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
with DatabaseManager() as db:
|
||||||
|
with SessionRepositoryManager(db) as repo:
|
||||||
|
# Use the repository
|
||||||
|
session = repo.create_session()
|
||||||
|
current_session = repo.get_current_session()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the SessionRepositoryManager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def __enter__(self) -> 'SessionRepository':
|
||||||
|
"""
|
||||||
|
Initialize the SessionRepository and return it.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionRepository: The initialized repository
|
||||||
|
"""
|
||||||
|
repo = SessionRepository(self.db)
|
||||||
|
session_repo_var.set(repo)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[type],
|
||||||
|
exc_val: Optional[Exception],
|
||||||
|
exc_tb: Optional[object],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Reset the repository when exiting the context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: The exception type if an exception was raised
|
||||||
|
exc_val: The exception value if an exception was raised
|
||||||
|
exc_tb: The traceback if an exception was raised
|
||||||
|
"""
|
||||||
|
# Reset the contextvar to None
|
||||||
|
session_repo_var.set(None)
|
||||||
|
|
||||||
|
# Don't suppress exceptions
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_repository() -> 'SessionRepository':
|
||||||
|
"""
|
||||||
|
Get the current SessionRepository instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionRepository: The current repository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no repository has been initialized with SessionRepositoryManager
|
||||||
|
"""
|
||||||
|
repo = session_repo_var.get()
|
||||||
|
if repo is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No SessionRepository available. "
|
||||||
|
"Make sure to initialize one with SessionRepositoryManager first."
|
||||||
|
)
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRepository:
|
||||||
|
"""
|
||||||
|
Repository for handling Session records in the database.
|
||||||
|
|
||||||
|
This class provides methods for creating, retrieving, and managing Session records.
|
||||||
|
It abstracts away the database operations and provides a clean interface for working
|
||||||
|
with Session entities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db):
|
||||||
|
"""
|
||||||
|
Initialize the SessionRepository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Database connection to use (required)
|
||||||
|
"""
|
||||||
|
if db is None:
|
||||||
|
raise ValueError("Database connection is required for SessionRepository")
|
||||||
|
self.db = db
|
||||||
|
self.current_session = None
|
||||||
|
|
||||||
|
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
|
||||||
|
"""
|
||||||
|
Create a new session record in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metadata: Optional dictionary of additional metadata to store with the session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Session: The newly created session instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
peewee.DatabaseError: If there's an error creating the record
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get command line arguments
|
||||||
|
command_line = " ".join(sys.argv)
|
||||||
|
|
||||||
|
# Get program version
|
||||||
|
program_version = __version__
|
||||||
|
|
||||||
|
# JSON encode metadata if provided
|
||||||
|
machine_info = json.dumps(metadata) if metadata is not None else None
|
||||||
|
|
||||||
|
session = Session.create(
|
||||||
|
start_time=datetime.datetime.now(),
|
||||||
|
command_line=command_line,
|
||||||
|
program_version=program_version,
|
||||||
|
machine_info=machine_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the current session
|
||||||
|
self.current_session = session
|
||||||
|
|
||||||
|
logger.debug(f"Created new session with ID {session.id}")
|
||||||
|
return session
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to create session record: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_current_session(self) -> Optional[Session]:
|
||||||
|
"""
|
||||||
|
Get the current active session.
|
||||||
|
|
||||||
|
If no session has been created in this repository instance,
|
||||||
|
retrieves the most recent session from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Session]: The current session or None if no sessions exist
|
||||||
|
"""
|
||||||
|
if self.current_session is not None:
|
||||||
|
return self.current_session
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find the most recent session
|
||||||
|
session = Session.select().order_by(Session.created_at.desc()).first()
|
||||||
|
if session:
|
||||||
|
self.current_session = session
|
||||||
|
return session
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to get current session: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_current_session_id(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Get the ID of the current active session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: The ID of the current session or None if no session exists
|
||||||
|
"""
|
||||||
|
session = self.get_current_session()
|
||||||
|
return session.id if session else None
|
||||||
|
|
||||||
|
def get(self, session_id: int) -> Optional[Session]:
|
||||||
|
"""
|
||||||
|
Get a session by its ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Session]: The session with the given ID or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return Session.get_or_none(Session.id == session_id)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all(self) -> List[Session]:
|
||||||
|
"""
|
||||||
|
Get all sessions from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Session]: List of all sessions
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return list(Session.select().order_by(Session.created_at.desc()))
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_recent(self, limit: int = 10) -> List[Session]:
|
||||||
|
"""
|
||||||
|
Get the most recent sessions from the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of sessions to return (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Session]: List of the most recent sessions
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return list(
|
||||||
|
Session.select()
|
||||||
|
.order_by(Session.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
except peewee.DatabaseError as e:
|
||||||
|
logger.error(f"Failed to get recent sessions: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
@ -132,8 +132,8 @@ class TrajectoryRepository:
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: Optional[str] = None,
|
||||||
tool_parameters: Dict[str, Any],
|
tool_parameters: Optional[Dict[str, Any]] = None,
|
||||||
tool_result: Optional[Dict[str, Any]] = None,
|
tool_result: Optional[Dict[str, Any]] = None,
|
||||||
step_data: Optional[Dict[str, Any]] = None,
|
step_data: Optional[Dict[str, Any]] = None,
|
||||||
record_type: str = "tool_execution",
|
record_type: str = "tool_execution",
|
||||||
|
|
@ -149,8 +149,8 @@ class TrajectoryRepository:
|
||||||
Create a new trajectory record in the database.
|
Create a new trajectory record in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_name: Name of the tool that was executed
|
tool_name: Optional name of the tool that was executed
|
||||||
tool_parameters: Parameters passed to the tool (will be JSON encoded)
|
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
|
||||||
tool_result: Result returned by the tool (will be JSON encoded)
|
tool_result: Result returned by the tool (will be JSON encoded)
|
||||||
step_data: UI rendering data (will be JSON encoded)
|
step_data: UI rendering data (will be JSON encoded)
|
||||||
record_type: Type of trajectory record
|
record_type: Type of trajectory record
|
||||||
|
|
@ -170,7 +170,7 @@ class TrajectoryRepository:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Serialize JSON fields
|
# Serialize JSON fields
|
||||||
tool_parameters_json = json.dumps(tool_parameters)
|
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
|
||||||
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
||||||
step_data_json = json.dumps(step_data) if step_data is not None else None
|
step_data_json = json.dumps(step_data) if step_data is not None else None
|
||||||
|
|
||||||
|
|
@ -185,7 +185,7 @@ class TrajectoryRepository:
|
||||||
# Create the trajectory record
|
# Create the trajectory record
|
||||||
trajectory = Trajectory.create(
|
trajectory = Trajectory.create(
|
||||||
human_input=human_input,
|
human_input=human_input,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name or "", # Use empty string if tool_name is None
|
||||||
tool_parameters=tool_parameters_json,
|
tool_parameters=tool_parameters_json,
|
||||||
tool_result=tool_result_json,
|
tool_result=tool_result_json,
|
||||||
step_data=step_data_json,
|
step_data=step_data_json,
|
||||||
|
|
@ -197,7 +197,10 @@ class TrajectoryRepository:
|
||||||
error_type=error_type,
|
error_type=error_type,
|
||||||
error_details=error_details
|
error_details=error_details
|
||||||
)
|
)
|
||||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
if tool_name:
|
||||||
|
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||||
return trajectory
|
return trajectory
|
||||||
except peewee.DatabaseError as e:
|
except peewee.DatabaseError as e:
|
||||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -154,6 +154,24 @@ class FallbackHandler:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||||
)
|
)
|
||||||
|
# Import repository classes directly to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"message": f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||||
|
"display_title": "Fallback Notification",
|
||||||
|
},
|
||||||
|
record_type="info",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
cpm(
|
cpm(
|
||||||
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||||
title="Fallback Notification",
|
title="Fallback Notification",
|
||||||
|
|
@ -163,6 +181,24 @@ class FallbackHandler:
|
||||||
if result_list:
|
if result_list:
|
||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
|
# Import repository classes directly to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"message": "All fallback models have failed.",
|
||||||
|
"display_title": "Fallback Failed",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
cpm("All fallback models have failed.", title="Fallback Failed")
|
cpm("All fallback models have failed.", title="Fallback Failed")
|
||||||
|
|
||||||
current_failing_tool_name = self.current_failing_tool_name
|
current_failing_tool_name = self.current_failing_tool_name
|
||||||
|
|
|
||||||
|
|
@ -234,6 +234,24 @@ def create_llm_client(
|
||||||
elif supports_temperature:
|
elif supports_temperature:
|
||||||
if temperature is None:
|
if temperature is None:
|
||||||
temperature = 0.7
|
temperature = 0.7
|
||||||
|
# Import repository classes directly to avoid circular imports
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||||
|
from ra_aid.database.connection import get_db
|
||||||
|
|
||||||
|
# Create repositories directly
|
||||||
|
trajectory_repo = TrajectoryRepository(get_db())
|
||||||
|
human_input_repo = HumanInputRepository(get_db())
|
||||||
|
human_input_id = human_input_repo.get_most_recent_id()
|
||||||
|
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
|
||||||
|
"display_title": "Information",
|
||||||
|
},
|
||||||
|
record_type="info",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
cpm(
|
cpm(
|
||||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
id = pw.AutoField()
|
id = pw.AutoField()
|
||||||
created_at = pw.DateTimeField()
|
created_at = pw.DateTimeField()
|
||||||
updated_at = pw.DateTimeField()
|
updated_at = pw.DateTimeField()
|
||||||
tool_name = pw.TextField()
|
tool_name = pw.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_parameters = pw.TextField() # JSON-encoded parameters
|
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
|
||||||
tool_result = pw.TextField() # JSON-encoded result
|
tool_result = pw.TextField(null=True) # JSON-encoded result
|
||||||
step_data = pw.TextField() # JSON-encoded UI rendering data
|
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
|
||||||
record_type = pw.TextField() # Type of trajectory record
|
record_type = pw.TextField(null=True) # Type of trajectory record
|
||||||
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
||||||
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
||||||
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""Peewee migrations -- 008_20250311_191232_add_session_model.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Create the session table for storing application session information."""
|
||||||
|
|
||||||
|
table_exists = False
|
||||||
|
# Check if the table already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT id FROM session LIMIT 1")
|
||||||
|
# If we reach here, the table exists
|
||||||
|
table_exists = True
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Table doesn't exist, safe to create
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Create the Session model - this registers it in migrator.orm as 'Session'
|
||||||
|
@migrator.create_model
|
||||||
|
class Session(pw.Model):
|
||||||
|
id = pw.AutoField()
|
||||||
|
created_at = pw.DateTimeField()
|
||||||
|
updated_at = pw.DateTimeField()
|
||||||
|
start_time = pw.DateTimeField()
|
||||||
|
command_line = pw.TextField(null=True)
|
||||||
|
program_version = pw.TextField(null=True)
|
||||||
|
machine_info = pw.TextField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table_name = "session"
|
||||||
|
|
||||||
|
# FIX: Explicitly register the model under the lowercase table name key
|
||||||
|
# This ensures that later migrations can access it via either:
|
||||||
|
# - migrator.orm['Session'] (class name)
|
||||||
|
# - migrator.orm['session'] (table name)
|
||||||
|
if 'Session' in migrator.orm:
|
||||||
|
migrator.orm['session'] = migrator.orm['Session']
|
||||||
|
|
||||||
|
# Only return after model registration is complete
|
||||||
|
if table_exists:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove the session table."""
|
||||||
|
|
||||||
|
migrator.remove_model('session')
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Add session foreign key to HumanInput table."""
|
||||||
|
|
||||||
|
# Get the Session model from migrator.orm
|
||||||
|
Session = migrator.orm['session']
|
||||||
|
|
||||||
|
# Check if the column already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT session_id FROM human_input LIMIT 1")
|
||||||
|
# If we reach here, the column exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add the session_id foreign key column
|
||||||
|
migrator.add_fields(
|
||||||
|
'human_input',
|
||||||
|
session=pw.ForeignKeyField(
|
||||||
|
Session,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='CASCADE'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove session foreign key from HumanInput table."""
|
||||||
|
|
||||||
|
migrator.remove_fields('human_input', 'session')
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Add session foreign key to KeyFact table."""
|
||||||
|
|
||||||
|
# Get the Session model from migrator.orm
|
||||||
|
Session = migrator.orm['session']
|
||||||
|
|
||||||
|
# Check if the column already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT session_id FROM key_fact LIMIT 1")
|
||||||
|
# If we reach here, the column exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add the session_id foreign key column
|
||||||
|
migrator.add_fields(
|
||||||
|
'key_fact',
|
||||||
|
session=pw.ForeignKeyField(
|
||||||
|
Session,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='CASCADE'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove session foreign key from KeyFact table."""
|
||||||
|
|
||||||
|
migrator.remove_fields('key_fact', 'session')
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Add session foreign key to KeySnippet table."""
|
||||||
|
|
||||||
|
# Get the Session model from migrator.orm
|
||||||
|
Session = migrator.orm['session']
|
||||||
|
|
||||||
|
# Check if the column already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1")
|
||||||
|
# If we reach here, the column exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add the session_id foreign key column
|
||||||
|
migrator.add_fields(
|
||||||
|
'key_snippet',
|
||||||
|
session=pw.ForeignKeyField(
|
||||||
|
Session,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='CASCADE'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove session foreign key from KeySnippet table."""
|
||||||
|
|
||||||
|
migrator.remove_fields('key_snippet', 'session')
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Add session foreign key to ResearchNote table."""
|
||||||
|
|
||||||
|
# Get the Session model from migrator.orm
|
||||||
|
Session = migrator.orm['session']
|
||||||
|
|
||||||
|
# Check if the column already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT session_id FROM research_note LIMIT 1")
|
||||||
|
# If we reach here, the column exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add the session_id foreign key column
|
||||||
|
migrator.add_fields(
|
||||||
|
'research_note',
|
||||||
|
session=pw.ForeignKeyField(
|
||||||
|
Session,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='CASCADE'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove session foreign key from ResearchNote table."""
|
||||||
|
|
||||||
|
migrator.remove_fields('research_note', 'session')
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py.
|
||||||
|
|
||||||
|
Some examples (model - class or model name)::
|
||||||
|
|
||||||
|
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||||
|
> Model = migrator.ModelClass # Return model in current state by name
|
||||||
|
|
||||||
|
> migrator.sql(sql) # Run custom SQL
|
||||||
|
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||||
|
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||||
|
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||||
|
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||||
|
> migrator.change_fields(model, **fields) # Change fields
|
||||||
|
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||||
|
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||||
|
> migrator.rename_table(model, new_table_name)
|
||||||
|
> migrator.add_index(model, *col_names, unique=False)
|
||||||
|
> migrator.add_not_null(model, *field_names)
|
||||||
|
> migrator.add_default(model, field_name, default)
|
||||||
|
> migrator.add_constraint(model, name, sql)
|
||||||
|
> migrator.drop_index(model, *col_names)
|
||||||
|
> migrator.drop_not_null(model, *field_names)
|
||||||
|
> migrator.drop_constraints(model, *constraints)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
|
import peewee as pw
|
||||||
|
from peewee_migrate import Migrator
|
||||||
|
|
||||||
|
|
||||||
|
with suppress(ImportError):
|
||||||
|
import playhouse.postgres_ext as pw_pext
|
||||||
|
|
||||||
|
|
||||||
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Add session foreign key to Trajectory table."""
|
||||||
|
|
||||||
|
# Get the Session model from migrator.orm
|
||||||
|
Session = migrator.orm['session']
|
||||||
|
|
||||||
|
# Check if the column already exists
|
||||||
|
try:
|
||||||
|
database.execute_sql("SELECT session_id FROM trajectory LIMIT 1")
|
||||||
|
# If we reach here, the column exists
|
||||||
|
return
|
||||||
|
except pw.OperationalError:
|
||||||
|
# Column doesn't exist, safe to add
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Add the session_id foreign key column
|
||||||
|
migrator.add_fields(
|
||||||
|
'trajectory',
|
||||||
|
session=pw.ForeignKeyField(
|
||||||
|
Session,
|
||||||
|
null=True,
|
||||||
|
field='id',
|
||||||
|
on_delete='CASCADE'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
|
"""Remove session foreign key from Trajectory table."""
|
||||||
|
|
||||||
|
migrator.remove_fields('trajectory', 'session')
|
||||||
|
|
@ -17,6 +17,8 @@ __all__ = [
|
||||||
|
|
||||||
from ra_aid.file_listing import FileListerError, get_file_listing
|
from ra_aid.file_listing import FileListerError, get_file_listing
|
||||||
from ra_aid.project_state import ProjectStateError, is_new_project
|
from ra_aid.project_state import ProjectStateError, is_new_project
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -130,6 +132,24 @@ def display_project_status(info: ProjectInfo) -> None:
|
||||||
{status} with **{file_count} file(s)**
|
{status} with **{file_count} file(s)**
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Record project status in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"project_status": "new" if info.is_new else "existing",
|
||||||
|
"file_count": file_count,
|
||||||
|
"total_files": info.total_files,
|
||||||
|
"display_title": "Project Status",
|
||||||
|
},
|
||||||
|
record_type="info",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Silently continue if trajectory recording fails
|
||||||
|
pass
|
||||||
|
|
||||||
# Create and display panel
|
# Create and display panel
|
||||||
console = Console()
|
console = Console()
|
||||||
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))
|
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))
|
||||||
|
|
@ -194,4 +194,5 @@ THE AGENT IS VERY FORGETFUL AND YOUR WRITING MUST INCLUDE REMARKS ABOUT HOW IT S
|
||||||
REMEMBER WE ARE INSTRUCTING THE AGENT **HOW TO DO RESEARCH ABOUT WHAT ALREADY EXISTS** AT THIS POINT USING THE TOOLS AVAILABLE. YOU ARE NOT TO DO THE ACTUAL RESEARCH YOURSELF. IF AN IMPLEMENTATION IS REQUESTED, THE AGENT SHOULD BE INSTRUCTED TO CALL request_task_implementation BUT ONLY AFTER EMITTING RESEARCH NOTES, KEY FACTS, AND KEY SNIPPETS AS RELEVANT.
|
REMEMBER WE ARE INSTRUCTING THE AGENT **HOW TO DO RESEARCH ABOUT WHAT ALREADY EXISTS** AT THIS POINT USING THE TOOLS AVAILABLE. YOU ARE NOT TO DO THE ACTUAL RESEARCH YOURSELF. IF AN IMPLEMENTATION IS REQUESTED, THE AGENT SHOULD BE INSTRUCTED TO CALL request_task_implementation BUT ONLY AFTER EMITTING RESEARCH NOTES, KEY FACTS, AND KEY SNIPPETS AS RELEVANT.
|
||||||
IT IS IMPERATIVE THAT WE DO NOT START DIRECTLY IMPLEMENTING ANYTHING AT THIS POINT. WE ARE RESEARCHING, THEN CALLING request_implementation *AT MOST ONCE*.
|
IT IS IMPERATIVE THAT WE DO NOT START DIRECTLY IMPLEMENTING ANYTHING AT THIS POINT. WE ARE RESEARCHING, THEN CALLING request_implementation *AT MOST ONCE*.
|
||||||
IT IS IMPERATIVE THE AGENT EMITS KEY FACTS AND THOROUGH RESEARCH NOTES AT THIS POINT. THE RESEARCH NOTES CAN JUST BE THOUGHTS AT THIS POINT IF IT IS A NEW PROJECT.
|
IT IS IMPERATIVE THE AGENT EMITS KEY FACTS AND THOROUGH RESEARCH NOTES AT THIS POINT. THE RESEARCH NOTES CAN JUST BE THOUGHTS AT THIS POINT IF IT IS A NEW PROJECT.
|
||||||
|
THE AGENT MUST ALWAYS CALL emit_research_notes AT LEAST ONCE, ESPECIALLY IF IT CALLS ask_expert.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,12 @@ from ra_aid.agent_context import (
|
||||||
reset_completion_flags,
|
reset_completion_flags,
|
||||||
)
|
)
|
||||||
from ra_aid.config import DEFAULT_MODEL
|
from ra_aid.config import DEFAULT_MODEL
|
||||||
from ra_aid.console.formatting import print_error
|
from ra_aid.console.formatting import print_error, print_task_header
|
||||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
from ra_aid.exceptions import AgentInterrupt
|
from ra_aid.exceptions import AgentInterrupt
|
||||||
|
|
@ -27,8 +28,7 @@ from ra_aid.model_formatters import format_key_facts_dict
|
||||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||||
|
|
||||||
from ..console import print_task_header
|
from ra_aid.llm import initialize_llm
|
||||||
from ..llm import initialize_llm
|
|
||||||
from .human import ask_human
|
from .human import ask_human
|
||||||
from .memory import get_related_files, get_work_log
|
from .memory import get_related_files, get_work_log
|
||||||
|
|
||||||
|
|
@ -63,7 +63,23 @@ def request_research(query: str) -> ResearchResult:
|
||||||
# Check recursion depth
|
# Check recursion depth
|
||||||
current_depth = get_depth()
|
current_depth = get_depth()
|
||||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||||
print_error("Maximum research recursion depth reached")
|
error_message = "Maximum research recursion depth reached"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
try:
|
try:
|
||||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|
@ -110,7 +126,23 @@ def request_research(query: str) -> ResearchResult:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during research: {str(e)}")
|
error_message = f"Error during research: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
success = False
|
success = False
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -195,7 +227,23 @@ def request_web_research(query: str) -> ResearchResult:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during web research: {str(e)}")
|
error_message = f"Error during web research: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
success = False
|
success = False
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -347,6 +395,19 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print_task_header(task_spec)
|
print_task_header(task_spec)
|
||||||
|
|
||||||
|
# Record task display in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"task": task_spec,
|
||||||
|
"display_title": "Task",
|
||||||
|
},
|
||||||
|
record_type="task_display",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
# Run implementation agent
|
# Run implementation agent
|
||||||
from ..agents.implementation_agent import run_task_implementation_agent
|
from ..agents.implementation_agent import run_task_implementation_agent
|
||||||
|
|
||||||
|
|
@ -372,7 +433,23 @@ def request_task_implementation(task_spec: str) -> str:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during task implementation: {str(e)}")
|
error_message = f"Error during task implementation: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
success = False
|
success = False
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
|
|
@ -503,7 +580,23 @@ def request_implementation(task_spec: str) -> str:
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_error(f"Error during planning: {str(e)}")
|
error_message = f"Error during planning: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_message,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
print_error(error_message)
|
||||||
success = False
|
success = False
|
||||||
reason = f"error: {str(e)}"
|
reason = f"error: {str(e)}"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,9 @@ from rich.panel import Panel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from ..database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ..database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ..database.repositories.related_files_repository import get_related_files_repository
|
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||||
|
|
@ -72,6 +75,23 @@ def emit_expert_context(context: str) -> str:
|
||||||
"""
|
"""
|
||||||
expert_context["text"].append(context)
|
expert_context["text"].append(context)
|
||||||
|
|
||||||
|
# Record expert context in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_expert_context",
|
||||||
|
tool_parameters={"context_length": len(context)},
|
||||||
|
step_data={
|
||||||
|
"display_title": "Expert Context",
|
||||||
|
"context_length": len(context),
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to record trajectory: {e}")
|
||||||
|
|
||||||
# Create and display status panel
|
# Create and display status panel
|
||||||
panel_content = f"Added expert context ({len(context)} characters)"
|
panel_content = f"Added expert context ({len(context)} characters)"
|
||||||
console.print(Panel(panel_content, title="Expert Context", border_style="blue"))
|
console.print(Panel(panel_content, title="Expert Context", border_style="blue"))
|
||||||
|
|
@ -184,6 +204,23 @@ def ask_expert(question: str) -> str:
|
||||||
# Build display query (just question)
|
# Build display query (just question)
|
||||||
display_query = "# Question\n" + question
|
display_query = "# Question\n" + question
|
||||||
|
|
||||||
|
# Record expert query in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="ask_expert",
|
||||||
|
tool_parameters={"question": question},
|
||||||
|
step_data={
|
||||||
|
"display_title": "Expert Query",
|
||||||
|
"question": question,
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to record trajectory: {e}")
|
||||||
|
|
||||||
# Show only question in panel
|
# Show only question in panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(display_query), title="🤔 Expert Query", border_style="yellow")
|
Panel(Markdown(display_query), title="🤔 Expert Query", border_style="yellow")
|
||||||
|
|
@ -263,6 +300,23 @@ def ask_expert(question: str) -> str:
|
||||||
logger.error(f"Exception during content processing: {str(e)}")
|
logger.error(f"Exception during content processing: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Record expert response in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="ask_expert",
|
||||||
|
tool_parameters={"question": question},
|
||||||
|
step_data={
|
||||||
|
"display_title": "Expert Response",
|
||||||
|
"response_length": len(content),
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to record trajectory: {e}")
|
||||||
|
|
||||||
# Format and display response
|
# Format and display response
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(content), title="Expert Response", border_style="blue")
|
Panel(Markdown(content), title="Expert Response", border_style="blue")
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ from rich.panel import Panel
|
||||||
from ra_aid.console import console
|
from ra_aid.console import console
|
||||||
from ra_aid.console.formatting import print_error
|
from ra_aid.console.formatting import print_error
|
||||||
from ra_aid.tools.memory import emit_related_files
|
from ra_aid.tools.memory import emit_related_files
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
|
|
||||||
def truncate_display_str(s: str, max_length: int = 30) -> str:
|
def truncate_display_str(s: str, max_length: int = 30) -> str:
|
||||||
|
|
@ -54,6 +56,32 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
||||||
path = Path(filepath)
|
path = Path(filepath)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
msg = f"File not found: {filepath}"
|
msg = f"File not found: {filepath}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": msg,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=msg,
|
||||||
|
tool_name="file_str_replace",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"old_str": old_str,
|
||||||
|
"new_str": new_str,
|
||||||
|
"replace_all": replace_all
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||||
|
pass
|
||||||
|
|
||||||
print_error(msg)
|
print_error(msg)
|
||||||
return {"success": False, "message": msg}
|
return {"success": False, "message": msg}
|
||||||
|
|
||||||
|
|
@ -62,10 +90,62 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
||||||
|
|
||||||
if count == 0:
|
if count == 0:
|
||||||
msg = f"String not found: {truncate_display_str(old_str)}"
|
msg = f"String not found: {truncate_display_str(old_str)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": msg,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=msg,
|
||||||
|
tool_name="file_str_replace",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"old_str": old_str,
|
||||||
|
"new_str": new_str,
|
||||||
|
"replace_all": replace_all
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||||
|
pass
|
||||||
|
|
||||||
print_error(msg)
|
print_error(msg)
|
||||||
return {"success": False, "message": msg}
|
return {"success": False, "message": msg}
|
||||||
elif count > 1 and not replace_all:
|
elif count > 1 and not replace_all:
|
||||||
msg = f"String appears {count} times - must be unique (use replace_all=True to replace all occurrences)"
|
msg = f"String appears {count} times - must be unique (use replace_all=True to replace all occurrences)"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": msg,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=msg,
|
||||||
|
tool_name="file_str_replace",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"old_str": old_str,
|
||||||
|
"new_str": new_str,
|
||||||
|
"replace_all": replace_all
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||||
|
pass
|
||||||
|
|
||||||
print_error(msg)
|
print_error(msg)
|
||||||
return {"success": False, "message": msg}
|
return {"success": False, "message": msg}
|
||||||
|
|
||||||
|
|
@ -93,7 +173,34 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
||||||
emit_related_files.invoke({"files": [filepath]})
|
emit_related_files.invoke({"files": [filepath]})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Don't let related files error affect main function success
|
# Don't let related files error affect main function success
|
||||||
print_error(f"Note: Could not add to related files: {str(e)}")
|
error_msg = f"Note: Could not add to related files: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": error_msg,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_msg,
|
||||||
|
tool_name="file_str_replace",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"old_str": old_str,
|
||||||
|
"new_str": new_str,
|
||||||
|
"replace_all": replace_all
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||||
|
pass
|
||||||
|
|
||||||
|
print_error(error_msg)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -102,5 +209,31 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"Error: {str(e)}"
|
msg = f"Error: {str(e)}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
step_data={
|
||||||
|
"error_message": msg,
|
||||||
|
"display_title": "Error",
|
||||||
|
},
|
||||||
|
record_type="error",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=msg,
|
||||||
|
tool_name="file_str_replace",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"old_str": old_str,
|
||||||
|
"new_str": new_str,
|
||||||
|
"replace_all": replace_all
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||||
|
pass
|
||||||
|
|
||||||
print_error(msg)
|
print_error(msg)
|
||||||
return {"success": False, "message": msg}
|
return {"success": False, "message": msg}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import fnmatch
|
import fnmatch
|
||||||
from typing import List, Tuple
|
import logging
|
||||||
|
from typing import List, Tuple, Dict, Optional, Any
|
||||||
|
|
||||||
from fuzzywuzzy import process
|
from fuzzywuzzy import process
|
||||||
from git import Repo, exc
|
from git import Repo, exc
|
||||||
|
|
@ -12,6 +13,49 @@ from ra_aid.file_listing import get_all_project_files, FileListerError
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
def record_trajectory(
|
||||||
|
tool_name: str,
|
||||||
|
tool_parameters: Dict,
|
||||||
|
step_data: Dict,
|
||||||
|
record_type: str = "tool_execution",
|
||||||
|
is_error: bool = False,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
error_type: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
tool_parameters: Parameters passed to the tool
|
||||||
|
step_data: UI rendering data
|
||||||
|
record_type: Type of trajectory record
|
||||||
|
is_error: Flag indicating if this record represents an error
|
||||||
|
error_message: The error message
|
||||||
|
error_type: The type/class of the error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_parameters=tool_parameters,
|
||||||
|
step_data=step_data,
|
||||||
|
record_type=record_type,
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=is_error,
|
||||||
|
error_message=error_message,
|
||||||
|
error_type=error_type
|
||||||
|
)
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
|
# If either the repository modules can't be imported or no repository is available,
|
||||||
|
# just log and continue without recording trajectory
|
||||||
|
logging.debug("Skipping trajectory recording: repositories not available")
|
||||||
|
|
||||||
DEFAULT_EXCLUDE_PATTERNS = [
|
DEFAULT_EXCLUDE_PATTERNS = [
|
||||||
"*.pyc",
|
"*.pyc",
|
||||||
"__pycache__/*",
|
"__pycache__/*",
|
||||||
|
|
@ -57,7 +101,32 @@ def fuzzy_find_project_files(
|
||||||
"""
|
"""
|
||||||
# Validate threshold
|
# Validate threshold
|
||||||
if not 0 <= threshold <= 100:
|
if not 0 <= threshold <= 100:
|
||||||
raise ValueError("Threshold must be between 0 and 100")
|
error_msg = "Threshold must be between 0 and 100"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="fuzzy_find_project_files",
|
||||||
|
tool_parameters={
|
||||||
|
"search_term": search_term,
|
||||||
|
"repo_path": repo_path,
|
||||||
|
"threshold": threshold,
|
||||||
|
"max_results": max_results,
|
||||||
|
"include_paths": include_paths,
|
||||||
|
"exclude_patterns": exclude_patterns,
|
||||||
|
"include_hidden": include_hidden
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"search_term": search_term,
|
||||||
|
"display_title": "Invalid Threshold Value",
|
||||||
|
"error_message": error_msg
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_msg,
|
||||||
|
error_type="ValueError"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
# Handle empty search term as special case
|
# Handle empty search term as special case
|
||||||
if not search_term:
|
if not search_term:
|
||||||
|
|
@ -126,6 +195,27 @@ def fuzzy_find_project_files(
|
||||||
else:
|
else:
|
||||||
info_sections.append("## Results\n*No matches found*")
|
info_sections.append("## Results\n*No matches found*")
|
||||||
|
|
||||||
|
# Record fuzzy find in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="fuzzy_find_project_files",
|
||||||
|
tool_parameters={
|
||||||
|
"search_term": search_term,
|
||||||
|
"repo_path": repo_path,
|
||||||
|
"threshold": threshold,
|
||||||
|
"max_results": max_results,
|
||||||
|
"include_paths": include_paths,
|
||||||
|
"exclude_patterns": exclude_patterns,
|
||||||
|
"include_hidden": include_hidden
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"search_term": search_term,
|
||||||
|
"display_title": "Fuzzy Find Results",
|
||||||
|
"total_files": len(all_files),
|
||||||
|
"matches_found": len(filtered_matches)
|
||||||
|
},
|
||||||
|
record_type="tool_execution"
|
||||||
|
)
|
||||||
|
|
||||||
# Display the panel
|
# Display the panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
|
|
@ -138,5 +228,30 @@ def fuzzy_find_project_files(
|
||||||
return filtered_matches
|
return filtered_matches
|
||||||
|
|
||||||
except FileListerError as e:
|
except FileListerError as e:
|
||||||
console.print(f"[bold red]Error listing files: {e}[/bold red]")
|
error_msg = f"Error listing files: {e}"
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="fuzzy_find_project_files",
|
||||||
|
tool_parameters={
|
||||||
|
"search_term": search_term,
|
||||||
|
"repo_path": repo_path,
|
||||||
|
"threshold": threshold,
|
||||||
|
"max_results": max_results,
|
||||||
|
"include_paths": include_paths,
|
||||||
|
"exclude_patterns": exclude_patterns,
|
||||||
|
"include_hidden": include_hidden
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"search_term": search_term,
|
||||||
|
"display_title": "Fuzzy Find Error",
|
||||||
|
"error_message": error_msg
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_msg,
|
||||||
|
error_type=type(e).__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
|
||||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||||
from ra_aid.model_formatters import key_snippets_formatter
|
from ra_aid.model_formatters import key_snippets_formatter
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
@ -69,6 +70,22 @@ def emit_research_notes(notes: str) -> str:
|
||||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||||
formatted_note = format_research_note(note_id, notes)
|
formatted_note = format_research_note(note_id, notes)
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_research_notes",
|
||||||
|
tool_parameters={"notes": notes},
|
||||||
|
step_data={
|
||||||
|
"note_id": note_id,
|
||||||
|
"display_title": "Research Notes",
|
||||||
|
},
|
||||||
|
record_type="memory_operation",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
# Display formatted note
|
# Display formatted note
|
||||||
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
|
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
|
||||||
|
|
||||||
|
|
@ -123,6 +140,23 @@ def emit_key_facts(facts: List[str]) -> str:
|
||||||
console.print(f"Error storing fact: {str(e)}", style="red")
|
console.print(f"Error storing fact: {str(e)}", style="red")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_key_facts",
|
||||||
|
tool_parameters={"facts": [fact]},
|
||||||
|
step_data={
|
||||||
|
"fact_id": fact_id,
|
||||||
|
"fact": fact,
|
||||||
|
"display_title": f"Key Fact #{fact_id}",
|
||||||
|
},
|
||||||
|
record_type="memory_operation",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
# Display panel with ID
|
# Display panel with ID
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
|
|
@ -214,6 +248,32 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
||||||
if snippet_info["description"]:
|
if snippet_info["description"]:
|
||||||
display_text.extend(["", "**Description**:", snippet_info["description"]])
|
display_text.extend(["", "**Description**:", snippet_info["description"]])
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
try:
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_key_snippet",
|
||||||
|
tool_parameters={
|
||||||
|
"snippet_info": {
|
||||||
|
"filepath": snippet_info["filepath"],
|
||||||
|
"line_number": snippet_info["line_number"],
|
||||||
|
"description": snippet_info["description"],
|
||||||
|
# Omit the full snippet content to avoid duplicating large text in the database
|
||||||
|
"snippet_length": len(snippet_info["snippet"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"snippet_id": snippet_id,
|
||||||
|
"filepath": snippet_info["filepath"],
|
||||||
|
"line_number": snippet_info["line_number"],
|
||||||
|
"display_title": f"Key Snippet #{snippet_id}",
|
||||||
|
},
|
||||||
|
record_type="memory_operation",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
# Display panel
|
# Display panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
|
|
@ -248,6 +308,25 @@ def one_shot_completed(message: str) -> str:
|
||||||
message: Completion message to display
|
message: Completion message to display
|
||||||
"""
|
"""
|
||||||
mark_task_completed(message)
|
mark_task_completed(message)
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
human_input_id = None
|
||||||
|
try:
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="one_shot_completed",
|
||||||
|
tool_parameters={"message": message},
|
||||||
|
step_data={
|
||||||
|
"completion_message": message,
|
||||||
|
"display_title": "Task Completed",
|
||||||
|
},
|
||||||
|
record_type="task_completion",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||||
log_work_event(f"Task completed:\n\n{message}")
|
log_work_event(f"Task completed:\n\n{message}")
|
||||||
return "Completion noted."
|
return "Completion noted."
|
||||||
|
|
@ -261,6 +340,25 @@ def task_completed(message: str) -> str:
|
||||||
message: Message explaining how/why the task is complete
|
message: Message explaining how/why the task is complete
|
||||||
"""
|
"""
|
||||||
mark_task_completed(message)
|
mark_task_completed(message)
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
human_input_id = None
|
||||||
|
try:
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="task_completed",
|
||||||
|
tool_parameters={"message": message},
|
||||||
|
step_data={
|
||||||
|
"completion_message": message,
|
||||||
|
"display_title": "Task Completed",
|
||||||
|
},
|
||||||
|
record_type="task_completion",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||||
log_work_event(f"Task completed:\n\n{message}")
|
log_work_event(f"Task completed:\n\n{message}")
|
||||||
return "Completion noted."
|
return "Completion noted."
|
||||||
|
|
@ -275,6 +373,25 @@ def plan_implementation_completed(message: str) -> str:
|
||||||
"""
|
"""
|
||||||
mark_should_exit(propagation_depth=1)
|
mark_should_exit(propagation_depth=1)
|
||||||
mark_plan_completed(message)
|
mark_plan_completed(message)
|
||||||
|
|
||||||
|
# Record to trajectory before displaying panel
|
||||||
|
human_input_id = None
|
||||||
|
try:
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="plan_implementation_completed",
|
||||||
|
tool_parameters={"message": message},
|
||||||
|
step_data={
|
||||||
|
"completion_message": message,
|
||||||
|
"display_title": "Plan Executed",
|
||||||
|
},
|
||||||
|
record_type="plan_completion",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||||
log_work_event(f"Completed implementation:\n\n{message}")
|
log_work_event(f"Completed implementation:\n\n{message}")
|
||||||
return "Plan completion noted."
|
return "Plan completion noted."
|
||||||
|
|
@ -361,10 +478,29 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
|
|
||||||
results.append(f"File ID #{file_id}: {file}")
|
results.append(f"File ID #{file_id}: {file}")
|
||||||
|
|
||||||
# Rich output - single consolidated panel for added files
|
# Record to trajectory before displaying panel for added files
|
||||||
if added_files:
|
if added_files:
|
||||||
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
|
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
|
||||||
md_content = f"**Files Noted:**\n{files_added_md}"
|
md_content = f"**Files Noted:**\n{files_added_md}"
|
||||||
|
|
||||||
|
human_input_id = None
|
||||||
|
try:
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_related_files",
|
||||||
|
tool_parameters={"files": files},
|
||||||
|
step_data={
|
||||||
|
"added_files": [file for _, file in added_files],
|
||||||
|
"added_file_ids": [file_id for file_id, _ in added_files],
|
||||||
|
"display_title": "Related Files Noted",
|
||||||
|
},
|
||||||
|
record_type="memory_operation",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(md_content),
|
Markdown(md_content),
|
||||||
|
|
@ -373,10 +509,28 @@ def emit_related_files(files: List[str]) -> str:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Display skipped binary files
|
# Record to trajectory before displaying panel for binary files
|
||||||
if binary_files:
|
if binary_files:
|
||||||
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
||||||
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
|
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
|
||||||
|
|
||||||
|
human_input_id = None
|
||||||
|
try:
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="emit_related_files",
|
||||||
|
tool_parameters={"files": files},
|
||||||
|
step_data={
|
||||||
|
"binary_files": binary_files,
|
||||||
|
"display_title": "Binary Files Not Added",
|
||||||
|
},
|
||||||
|
record_type="memory_operation",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(md_content),
|
Markdown(md_content),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import os.path
|
import os.path
|
||||||
import time
|
import time
|
||||||
from typing import Dict
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
@ -16,6 +16,49 @@ console = Console()
|
||||||
CHUNK_SIZE = 8192
|
CHUNK_SIZE = 8192
|
||||||
|
|
||||||
|
|
||||||
|
def record_trajectory(
|
||||||
|
tool_name: str,
|
||||||
|
tool_parameters: Dict,
|
||||||
|
step_data: Dict,
|
||||||
|
record_type: str = "tool_execution",
|
||||||
|
is_error: bool = False,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
error_type: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
tool_parameters: Parameters passed to the tool
|
||||||
|
step_data: UI rendering data
|
||||||
|
record_type: Type of trajectory record
|
||||||
|
is_error: Flag indicating if this record represents an error
|
||||||
|
error_message: The error message
|
||||||
|
error_type: The type/class of the error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_parameters=tool_parameters,
|
||||||
|
step_data=step_data,
|
||||||
|
record_type=record_type,
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=is_error,
|
||||||
|
error_message=error_message,
|
||||||
|
error_type=error_type
|
||||||
|
)
|
||||||
|
except (ImportError, RuntimeError):
|
||||||
|
# If either the repository modules can't be imported or no repository is available,
|
||||||
|
# just log and continue without recording trajectory
|
||||||
|
logging.debug("Skipping trajectory recording: repositories not available")
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||||
"""Read and return the contents of a text file.
|
"""Read and return the contents of a text file.
|
||||||
|
|
@ -29,10 +72,43 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(filepath):
|
if not os.path.exists(filepath):
|
||||||
|
# Record error in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="read_file_tool",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"encoding": encoding
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"filepath": filepath,
|
||||||
|
"display_title": "File Not Found",
|
||||||
|
"error_message": f"File not found: {filepath}"
|
||||||
|
},
|
||||||
|
is_error=True,
|
||||||
|
error_message=f"File not found: {filepath}",
|
||||||
|
error_type="FileNotFoundError"
|
||||||
|
)
|
||||||
raise FileNotFoundError(f"File not found: {filepath}")
|
raise FileNotFoundError(f"File not found: {filepath}")
|
||||||
|
|
||||||
# Check if the file is binary
|
# Check if the file is binary
|
||||||
if is_binary_file(filepath):
|
if is_binary_file(filepath):
|
||||||
|
# Record binary file error in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="read_file_tool",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"encoding": encoding
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"filepath": filepath,
|
||||||
|
"display_title": "Binary File Detected",
|
||||||
|
"error_message": f"Cannot read binary file: {filepath}"
|
||||||
|
},
|
||||||
|
is_error=True,
|
||||||
|
error_message="Cannot read binary file",
|
||||||
|
error_type="BinaryFileError"
|
||||||
|
)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Cannot read binary file: {filepath}",
|
f"Cannot read binary file: {filepath}",
|
||||||
|
|
@ -67,6 +143,22 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||||
logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s")
|
logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s")
|
||||||
logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines")
|
logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines")
|
||||||
|
|
||||||
|
# Record successful file read in trajectory
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="read_file_tool",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"encoding": encoding
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"filepath": filepath,
|
||||||
|
"display_title": "File Read",
|
||||||
|
"line_count": line_count,
|
||||||
|
"total_bytes": total_bytes,
|
||||||
|
"elapsed_time": elapsed
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s",
|
f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s",
|
||||||
|
|
@ -80,6 +172,25 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||||
|
|
||||||
return {"content": truncated}
|
return {"content": truncated}
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
|
# Record exception in trajectory (if it's not already a handled FileNotFoundError)
|
||||||
|
if not isinstance(e, FileNotFoundError):
|
||||||
|
record_trajectory(
|
||||||
|
tool_name="read_file_tool",
|
||||||
|
tool_parameters={
|
||||||
|
"filepath": filepath,
|
||||||
|
"encoding": encoding
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"filepath": filepath,
|
||||||
|
"display_title": "File Read Error",
|
||||||
|
"error_message": str(e)
|
||||||
|
},
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type=type(e).__name__
|
||||||
|
)
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,9 @@ from langchain_core.tools import tool
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -10,6 +13,24 @@ def existing_project_detected() -> dict:
|
||||||
"""
|
"""
|
||||||
When to call: Once you have confirmed that the current working directory contains project files.
|
When to call: Once you have confirmed that the current working directory contains project files.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
# Record detection in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="existing_project_detected",
|
||||||
|
tool_parameters={},
|
||||||
|
step_data={
|
||||||
|
"detection_type": "existing_project",
|
||||||
|
"display_title": "Existing Project Detected",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Continue even if trajectory recording fails
|
||||||
|
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel("📁 Existing Project Detected", style="bright_blue", padding=0))
|
console.print(Panel("📁 Existing Project Detected", style="bright_blue", padding=0))
|
||||||
return {
|
return {
|
||||||
"hint": (
|
"hint": (
|
||||||
|
|
@ -30,6 +51,24 @@ def monorepo_detected() -> dict:
|
||||||
"""
|
"""
|
||||||
When to call: After identifying that multiple packages or modules exist within a single repository.
|
When to call: After identifying that multiple packages or modules exist within a single repository.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
# Record detection in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="monorepo_detected",
|
||||||
|
tool_parameters={},
|
||||||
|
step_data={
|
||||||
|
"detection_type": "monorepo",
|
||||||
|
"display_title": "Monorepo Detected",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Continue even if trajectory recording fails
|
||||||
|
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel("📦 Monorepo Detected", style="bright_blue", padding=0))
|
console.print(Panel("📦 Monorepo Detected", style="bright_blue", padding=0))
|
||||||
return {
|
return {
|
||||||
"hint": (
|
"hint": (
|
||||||
|
|
@ -53,6 +92,24 @@ def ui_detected() -> dict:
|
||||||
"""
|
"""
|
||||||
When to call: After detecting that the project contains a user interface layer or front-end component.
|
When to call: After detecting that the project contains a user interface layer or front-end component.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
# Record detection in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="ui_detected",
|
||||||
|
tool_parameters={},
|
||||||
|
step_data={
|
||||||
|
"detection_type": "ui",
|
||||||
|
"display_title": "UI Detected",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Continue even if trajectory recording fails
|
||||||
|
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||||
|
|
||||||
console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0))
|
console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0))
|
||||||
return {
|
return {
|
||||||
"hint": (
|
"hint": (
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
from ra_aid.proc.interactive import run_interactive_command
|
from ra_aid.proc.interactive import run_interactive_command
|
||||||
from ra_aid.text.processing import truncate_output
|
from ra_aid.text.processing import truncate_output
|
||||||
|
|
||||||
|
|
@ -158,6 +160,30 @@ def ripgrep_search(
|
||||||
info_sections.append("\n".join(params))
|
info_sections.append("\n".join(params))
|
||||||
|
|
||||||
# Execute command
|
# Execute command
|
||||||
|
# Record ripgrep search in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="ripgrep_search",
|
||||||
|
tool_parameters={
|
||||||
|
"pattern": pattern,
|
||||||
|
"before_context_lines": before_context_lines,
|
||||||
|
"after_context_lines": after_context_lines,
|
||||||
|
"file_type": file_type,
|
||||||
|
"case_sensitive": case_sensitive,
|
||||||
|
"include_hidden": include_hidden,
|
||||||
|
"follow_links": follow_links,
|
||||||
|
"exclude_dirs": exclude_dirs,
|
||||||
|
"fixed_string": fixed_string
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"search_pattern": pattern,
|
||||||
|
"display_title": "Ripgrep Search",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
Panel(
|
Panel(
|
||||||
Markdown(f"Searching for: **{pattern}**"),
|
Markdown(f"Searching for: **{pattern}**"),
|
||||||
|
|
@ -179,5 +205,34 @@ def ripgrep_search(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
|
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="ripgrep_search",
|
||||||
|
tool_parameters={
|
||||||
|
"pattern": pattern,
|
||||||
|
"before_context_lines": before_context_lines,
|
||||||
|
"after_context_lines": after_context_lines,
|
||||||
|
"file_type": file_type,
|
||||||
|
"case_sensitive": case_sensitive,
|
||||||
|
"include_hidden": include_hidden,
|
||||||
|
"follow_links": follow_links,
|
||||||
|
"exclude_dirs": exclude_dirs,
|
||||||
|
"fixed_string": fixed_string
|
||||||
|
},
|
||||||
|
step_data={
|
||||||
|
"search_pattern": pattern,
|
||||||
|
"display_title": "Ripgrep Search Error",
|
||||||
|
"error_message": error_msg
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=error_msg,
|
||||||
|
error_type=type(e).__name__
|
||||||
|
)
|
||||||
|
|
||||||
console.print(Panel(error_msg, title="❌ Error", border_style="red"))
|
console.print(Panel(error_msg, title="❌ Error", border_style="red"))
|
||||||
return {"output": error_msg, "return_code": 1, "success": False}
|
return {"output": error_msg, "return_code": 1, "success": False}
|
||||||
|
|
@ -10,6 +10,8 @@ from ra_aid.proc.interactive import run_interactive_command
|
||||||
from ra_aid.text.processing import truncate_output
|
from ra_aid.text.processing import truncate_output
|
||||||
from ra_aid.tools.memory import log_work_event
|
from ra_aid.tools.memory import log_work_event
|
||||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
@ -54,6 +56,20 @@ def run_shell_command(
|
||||||
console.print(" " + get_cowboy_message())
|
console.print(" " + get_cowboy_message())
|
||||||
console.print("")
|
console.print("")
|
||||||
|
|
||||||
|
# Record tool execution in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="run_shell_command",
|
||||||
|
tool_parameters={"command": command, "timeout": timeout},
|
||||||
|
step_data={
|
||||||
|
"command": command,
|
||||||
|
"display_title": "Shell Command",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
# Show just the command in a simple panel
|
# Show just the command in a simple panel
|
||||||
console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow"))
|
console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow"))
|
||||||
|
|
||||||
|
|
@ -96,5 +112,23 @@ def run_shell_command(
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print()
|
print()
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="run_shell_command",
|
||||||
|
tool_parameters={"command": command, "timeout": timeout},
|
||||||
|
step_data={
|
||||||
|
"command": command,
|
||||||
|
"error": str(e),
|
||||||
|
"display_title": "Shell Error",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
||||||
return {"output": str(e), "return_code": 1, "success": False}
|
return {"output": str(e), "return_code": 1, "success": False}
|
||||||
|
|
@ -7,6 +7,9 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from tavily import TavilyClient
|
from tavily import TavilyClient
|
||||||
|
|
||||||
|
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||||
|
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,9 +24,44 @@ def web_search_tavily(query: str) -> Dict:
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing search results from Tavily
|
Dict containing search results from Tavily
|
||||||
"""
|
"""
|
||||||
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
# Record trajectory before displaying panel
|
||||||
|
trajectory_repo = get_trajectory_repository()
|
||||||
|
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="web_search_tavily",
|
||||||
|
tool_parameters={"query": query},
|
||||||
|
step_data={
|
||||||
|
"query": query,
|
||||||
|
"display_title": "Web Search",
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display search query panel
|
||||||
console.print(
|
console.print(
|
||||||
Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue")
|
Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue")
|
||||||
)
|
)
|
||||||
search_result = client.search(query=query)
|
|
||||||
return search_result
|
try:
|
||||||
|
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
||||||
|
search_result = client.search(query=query)
|
||||||
|
return search_result
|
||||||
|
except Exception as e:
|
||||||
|
# Record error in trajectory
|
||||||
|
trajectory_repo.create(
|
||||||
|
tool_name="web_search_tavily",
|
||||||
|
tool_parameters={"query": query},
|
||||||
|
step_data={
|
||||||
|
"query": query,
|
||||||
|
"display_title": "Web Search Error",
|
||||||
|
"error": str(e)
|
||||||
|
},
|
||||||
|
record_type="tool_execution",
|
||||||
|
human_input_id=human_input_id,
|
||||||
|
is_error=True,
|
||||||
|
error_message=str(e),
|
||||||
|
error_type=type(e).__name__
|
||||||
|
)
|
||||||
|
# Re-raise the exception to maintain original behavior
|
||||||
|
raise
|
||||||
|
|
@ -7,7 +7,7 @@ ensuring consistent test environments and proper isolation.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -26,6 +26,39 @@ def mock_config_repository():
|
||||||
yield repo
|
yield repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_trajectory_repository():
|
||||||
|
"""Mock the TrajectoryRepository to avoid database operations during tests."""
|
||||||
|
with patch('ra_aid.database.repositories.trajectory_repository.TrajectoryRepository') as mock:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
mock_repo.create.return_value = MagicMock(id=1)
|
||||||
|
mock.return_value = mock_repo
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_human_input_repository():
|
||||||
|
"""Mock the HumanInputRepository to avoid database operations during tests."""
|
||||||
|
with patch('ra_aid.database.repositories.human_input_repository.HumanInputRepository') as mock:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
mock_repo.get_most_recent_id.return_value = 1
|
||||||
|
mock_repo.create.return_value = MagicMock(id=1)
|
||||||
|
mock.return_value = mock_repo
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_repository_access(mock_trajectory_repository, mock_human_input_repository):
|
||||||
|
"""Mock all repository accessor functions."""
|
||||||
|
with patch('ra_aid.database.repositories.trajectory_repository.get_trajectory_repository',
|
||||||
|
return_value=mock_trajectory_repository):
|
||||||
|
with patch('ra_aid.database.repositories.human_input_repository.get_human_input_repository',
|
||||||
|
return_value=mock_human_input_repository):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def isolated_db_environment(tmp_path, monkeypatch):
|
def isolated_db_environment(tmp_path, monkeypatch):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,11 @@ from ra_aid.anthropic_token_limiter import (
|
||||||
state_modifier,
|
state_modifier,
|
||||||
)
|
)
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
|
from ra_aid.database.repositories.config_repository import (
|
||||||
|
ConfigRepositoryManager,
|
||||||
|
get_config_repository,
|
||||||
|
config_repo_var,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -34,7 +38,9 @@ def mock_model():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_repository():
|
def mock_config_repository():
|
||||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
with patch(
|
||||||
|
"ra_aid.database.repositories.config_repository.config_repo_var"
|
||||||
|
) as mock_repo_var:
|
||||||
# Setup a mock repository
|
# Setup a mock repository
|
||||||
mock_repo = MagicMock()
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
|
@ -44,6 +50,7 @@ def mock_config_repository():
|
||||||
# Setup get method to return config values
|
# Setup get method to return config values
|
||||||
def get_config(key, default=None):
|
def get_config(key, default=None):
|
||||||
return config.get(key, default)
|
return config.get(key, default)
|
||||||
|
|
||||||
mock_repo.get.side_effect = get_config
|
mock_repo.get.side_effect = get_config
|
||||||
|
|
||||||
# Setup get_all method to return all config values
|
# Setup get_all method to return all config values
|
||||||
|
|
@ -52,11 +59,13 @@ def mock_config_repository():
|
||||||
# Setup set method to update config values
|
# Setup set method to update config values
|
||||||
def set_config(key, value):
|
def set_config(key, value):
|
||||||
config[key] = value
|
config[key] = value
|
||||||
|
|
||||||
mock_repo.set.side_effect = set_config
|
mock_repo.set.side_effect = set_config
|
||||||
|
|
||||||
# Setup update method to update multiple config values
|
# Setup update method to update multiple config values
|
||||||
def update_config(update_dict):
|
def update_config(update_dict):
|
||||||
config.update(update_dict)
|
config.update(update_dict)
|
||||||
|
|
||||||
mock_repo.update.side_effect = update_config
|
mock_repo.update.side_effect = update_config
|
||||||
|
|
||||||
# Make the mock context var return our mock repo
|
# Make the mock context var return our mock repo
|
||||||
|
|
@ -65,15 +74,55 @@ def mock_config_repository():
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
# These tests have been moved to test_anthropic_token_limiter.py
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_trajectory_repository():
|
||||||
|
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||||
|
with patch(
|
||||||
|
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
|
||||||
|
) as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup create method to return a mock trajectory
|
||||||
|
def mock_create(**kwargs):
|
||||||
|
mock_trajectory = MagicMock()
|
||||||
|
mock_trajectory.id = 1
|
||||||
|
return mock_trajectory
|
||||||
|
|
||||||
|
mock_repo.create.side_effect = mock_create
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_human_input_repository():
|
||||||
|
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||||
|
with patch(
|
||||||
|
"ra_aid.database.repositories.human_input_repository.human_input_repo_var"
|
||||||
|
) as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup get_most_recent_id method to return a dummy ID
|
||||||
|
mock_repo.get_most_recent_id.return_value = 1
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||||
"""Test create_agent with Anthropic Claude model."""
|
"""Test create_agent with Anthropic Claude model."""
|
||||||
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
||||||
|
|
||||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
|
with (
|
||||||
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
|
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||||
|
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
|
||||||
|
):
|
||||||
mock_react.return_value = "react_agent"
|
mock_react.return_value = "react_agent"
|
||||||
agent = create_agent(mock_model, [])
|
agent = create_agent(mock_model, [])
|
||||||
|
|
||||||
|
|
@ -81,7 +130,7 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||||
mock_react.assert_called_once_with(
|
mock_react.assert_called_once_with(
|
||||||
mock_model,
|
mock_model,
|
||||||
[],
|
[],
|
||||||
interrupt_after=['tools'],
|
interrupt_after=["tools"],
|
||||||
version="v2",
|
version="v2",
|
||||||
state_modifier=mock_react.call_args[1]["state_modifier"],
|
state_modifier=mock_react.call_args[1]["state_modifier"],
|
||||||
name="React",
|
name="React",
|
||||||
|
|
@ -173,13 +222,17 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
|
def test_create_agent_anthropic_token_limiting_enabled(
|
||||||
|
mock_model, mock_config_repository
|
||||||
|
):
|
||||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||||
mock_config_repository.update({
|
mock_config_repository.update(
|
||||||
"provider": "anthropic",
|
{
|
||||||
"model": "claude-2",
|
"provider": "anthropic",
|
||||||
"limit_tokens": True,
|
"model": "claude-2",
|
||||||
})
|
"limit_tokens": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||||
|
|
@ -196,13 +249,17 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
|
||||||
assert callable(args[1]["state_modifier"])
|
assert callable(args[1]["state_modifier"])
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
|
def test_create_agent_anthropic_token_limiting_disabled(
|
||||||
|
mock_model, mock_config_repository
|
||||||
|
):
|
||||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||||
mock_config_repository.update({
|
mock_config_repository.update(
|
||||||
"provider": "anthropic",
|
{
|
||||||
"model": "claude-2",
|
"provider": "anthropic",
|
||||||
"limit_tokens": False,
|
"model": "claude-2",
|
||||||
})
|
"limit_tokens": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||||
|
|
@ -214,7 +271,9 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
|
||||||
agent = create_agent(mock_model, [])
|
agent = create_agent(mock_model, [])
|
||||||
|
|
||||||
assert agent == "react_agent"
|
assert agent == "react_agent"
|
||||||
mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2", name="React")
|
mock_react.assert_called_once_with(
|
||||||
|
mock_model, [], interrupt_after=["tools"], version="v2", name="React"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# These tests have been moved to test_anthropic_token_limiter.py
|
# These tests have been moved to test_anthropic_token_limiter.py
|
||||||
|
|
@ -482,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos
|
||||||
assert "Agent has crashed: Test crash message" in result
|
assert "Agent has crashed: Test crash message" in result
|
||||||
|
|
||||||
|
|
||||||
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
|
def test_run_agent_with_retry_handles_badrequest_error(
|
||||||
|
monkeypatch, mock_config_repository
|
||||||
|
):
|
||||||
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
||||||
from ra_aid.agent_context import agent_context, is_crashed
|
from ra_aid.agent_context import agent_context, is_crashed
|
||||||
from ra_aid.agent_utils import run_agent_with_retry
|
from ra_aid.agent_utils import run_agent_with_retry
|
||||||
|
|
@ -540,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_
|
||||||
assert is_crashed()
|
assert is_crashed()
|
||||||
|
|
||||||
|
|
||||||
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
|
def test_run_agent_with_retry_handles_api_badrequest_error(
|
||||||
|
monkeypatch, mock_config_repository
|
||||||
|
):
|
||||||
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
||||||
# Import APIError from anthropic module and patch it on the agent_utils module
|
# Import APIError from anthropic module and patch it on the agent_utils module
|
||||||
|
|
||||||
|
|
@ -613,5 +676,7 @@ def test_handle_api_error_resource_exhausted():
|
||||||
from ra_aid.agent_utils import _handle_api_error
|
from ra_aid.agent_utils import _handle_api_error
|
||||||
|
|
||||||
# ResourceExhausted exception should be handled without raising
|
# ResourceExhausted exception should be handled without raising
|
||||||
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
|
resource_exhausted_error = ResourceExhausted(
|
||||||
|
"429 Resource has been exhausted (e.g. check quota)."
|
||||||
|
)
|
||||||
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,40 @@ def mock_work_log_repository():
|
||||||
|
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_trajectory_repository():
|
||||||
|
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup create method to return a mock trajectory
|
||||||
|
def mock_create(**kwargs):
|
||||||
|
mock_trajectory = MagicMock()
|
||||||
|
mock_trajectory.id = 1
|
||||||
|
return mock_trajectory
|
||||||
|
mock_repo.create.side_effect = mock_create
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_human_input_repository():
|
||||||
|
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup get_most_recent_id method to return a dummy ID
|
||||||
|
mock_repo.get_most_recent_id.return_value = 1
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_functions():
|
def mock_functions():
|
||||||
"""Mock functions used in agent.py"""
|
"""Mock functions used in agent.py"""
|
||||||
|
|
@ -126,7 +160,9 @@ def mock_functions():
|
||||||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||||
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
|
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
|
||||||
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
|
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
|
||||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
|
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion, \
|
||||||
|
patch('ra_aid.tools.agent.get_trajectory_repository') as mock_get_trajectory_repo, \
|
||||||
|
patch('ra_aid.tools.agent.get_human_input_repository') as mock_get_human_input_repo:
|
||||||
|
|
||||||
# Setup mock return values
|
# Setup mock return values
|
||||||
mock_fact_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
|
mock_fact_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
|
||||||
|
|
@ -138,6 +174,15 @@ def mock_functions():
|
||||||
mock_get_work_log.return_value = "Test work log"
|
mock_get_work_log.return_value = "Test work log"
|
||||||
mock_get_completion.return_value = "Task completed"
|
mock_get_completion.return_value = "Task completed"
|
||||||
|
|
||||||
|
# Setup mock for trajectory repository
|
||||||
|
mock_trajectory_repo = MagicMock()
|
||||||
|
mock_get_trajectory_repo.return_value = mock_trajectory_repo
|
||||||
|
|
||||||
|
# Setup mock for human input repository
|
||||||
|
mock_human_input_repo = MagicMock()
|
||||||
|
mock_human_input_repo.get_most_recent_id.return_value = 1
|
||||||
|
mock_get_human_input_repo.return_value = mock_human_input_repo
|
||||||
|
|
||||||
# Return all mocks as a dictionary
|
# Return all mocks as a dictionary
|
||||||
yield {
|
yield {
|
||||||
'get_key_fact_repository': mock_get_fact_repo,
|
'get_key_fact_repository': mock_get_fact_repo,
|
||||||
|
|
@ -148,7 +193,9 @@ def mock_functions():
|
||||||
'get_related_files': mock_get_files,
|
'get_related_files': mock_get_files,
|
||||||
'get_work_log': mock_get_work_log,
|
'get_work_log': mock_get_work_log,
|
||||||
'reset_completion_flags': mock_reset,
|
'reset_completion_flags': mock_reset,
|
||||||
'get_completion_message': mock_get_completion
|
'get_completion_message': mock_get_completion,
|
||||||
|
'get_trajectory_repository': mock_get_trajectory_repo,
|
||||||
|
'get_human_input_repository': mock_get_human_input_repo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,40 @@ def mock_config_repository():
|
||||||
|
|
||||||
yield mock_repo
|
yield mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_trajectory_repository():
|
||||||
|
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup create method to return a mock trajectory
|
||||||
|
def mock_create(**kwargs):
|
||||||
|
mock_trajectory = MagicMock()
|
||||||
|
mock_trajectory.id = 1
|
||||||
|
return mock_trajectory
|
||||||
|
mock_repo.create.side_effect = mock_create
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_human_input_repository():
|
||||||
|
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||||
|
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||||
|
# Setup a mock repository
|
||||||
|
mock_repo = MagicMock()
|
||||||
|
|
||||||
|
# Setup get_most_recent_id method to return a dummy ID
|
||||||
|
mock_repo.get_most_recent_id.return_value = 1
|
||||||
|
|
||||||
|
# Make the mock context var return our mock repo
|
||||||
|
mock_repo_var.get.return_value = mock_repo
|
||||||
|
|
||||||
|
yield mock_repo
|
||||||
|
|
||||||
|
|
||||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||||
"""Test shell command execution in cowboy mode (no approval)"""
|
"""Test shell command execution in cowboy mode (no approval)"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
import { defineConfig } from '@vscode/test-cli';
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
files: 'out/test/**/*.test.js',
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
// See http://go.microsoft.com/fwlink/?LinkId=827846
|
||||||
|
// for the documentation about the extensions.json format
|
||||||
|
"recommendations": ["dbaeumer.vscode-eslint", "connor4312.esbuild-problem-matchers", "ms-vscode.extension-test-runner"]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
// A launch configuration that compiles the extension and then opens it inside a new window
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Run Extension",
|
||||||
|
"type": "extensionHost",
|
||||||
|
"request": "launch",
|
||||||
|
"args": [
|
||||||
|
"--extensionDevelopmentPath=${workspaceFolder}"
|
||||||
|
],
|
||||||
|
"outFiles": [
|
||||||
|
"${workspaceFolder}/dist/**/*.js"
|
||||||
|
],
|
||||||
|
"preLaunchTask": "${defaultBuildTask}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
// Place your settings in this file to overwrite default and user settings.
|
||||||
|
{
|
||||||
|
"files.exclude": {
|
||||||
|
"out": false, // set this to true to hide the "out" folder with the compiled JS files
|
||||||
|
"dist": false // set this to true to hide the "dist" folder with the compiled JS files
|
||||||
|
},
|
||||||
|
"search.exclude": {
|
||||||
|
"out": true, // set this to false to include "out" folder in search results
|
||||||
|
"dist": true // set this to false to include "dist" folder in search results
|
||||||
|
},
|
||||||
|
// Turn off tsc task auto detection since we have the necessary tasks as npm scripts
|
||||||
|
"typescript.tsc.autoDetect": "off"
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
// See https://go.microsoft.com/fwlink/?LinkId=733558
|
||||||
|
// for the documentation about the tasks.json format
|
||||||
|
{
|
||||||
|
"version": "2.0.0",
|
||||||
|
"tasks": [
|
||||||
|
{
|
||||||
|
"label": "watch",
|
||||||
|
"dependsOn": [
|
||||||
|
"npm: watch:tsc",
|
||||||
|
"npm: watch:esbuild"
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"reveal": "never"
|
||||||
|
},
|
||||||
|
"group": {
|
||||||
|
"kind": "build",
|
||||||
|
"isDefault": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "npm",
|
||||||
|
"script": "watch:esbuild",
|
||||||
|
"group": "build",
|
||||||
|
"problemMatcher": "$esbuild-watch",
|
||||||
|
"isBackground": true,
|
||||||
|
"label": "npm: watch:esbuild",
|
||||||
|
"presentation": {
|
||||||
|
"group": "watch",
|
||||||
|
"reveal": "never"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "npm",
|
||||||
|
"script": "watch:tsc",
|
||||||
|
"group": "build",
|
||||||
|
"problemMatcher": "$tsc-watch",
|
||||||
|
"isBackground": true,
|
||||||
|
"label": "npm: watch:tsc",
|
||||||
|
"presentation": {
|
||||||
|
"group": "watch",
|
||||||
|
"reveal": "never"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "npm",
|
||||||
|
"script": "watch-tests",
|
||||||
|
"problemMatcher": "$tsc-watch",
|
||||||
|
"isBackground": true,
|
||||||
|
"presentation": {
|
||||||
|
"reveal": "never",
|
||||||
|
"group": "watchers"
|
||||||
|
},
|
||||||
|
"group": "build"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"label": "tasks: watch-tests",
|
||||||
|
"dependsOn": [
|
||||||
|
"npm: watch",
|
||||||
|
"npm: watch-tests"
|
||||||
|
],
|
||||||
|
"problemMatcher": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
.vscode/**
|
||||||
|
.vscode-test/**
|
||||||
|
out/**
|
||||||
|
node_modules/**
|
||||||
|
src/**
|
||||||
|
.gitignore
|
||||||
|
.yarnrc
|
||||||
|
esbuild.js
|
||||||
|
vsc-extension-quickstart.md
|
||||||
|
**/tsconfig.json
|
||||||
|
**/eslint.config.mjs
|
||||||
|
**/*.map
|
||||||
|
**/*.ts
|
||||||
|
**/.vscode-test.*
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
# Change Log
|
||||||
|
|
||||||
|
All notable changes to the "ra-aid" extension will be documented in this file.
|
||||||
|
|
||||||
|
Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file.
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
- Initial release
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
# ra-aid README
|
||||||
|
|
||||||
|
This is the README for your extension "ra-aid". After writing up a brief description, we recommend including the following sections.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
Describe specific features of your extension including screenshots of your extension in action. Image paths are relative to this README file.
|
||||||
|
|
||||||
|
For example if there is an image subfolder under your extension project workspace:
|
||||||
|
|
||||||
|
\!\[feature X\]\(images/feature-x.png\)
|
||||||
|
|
||||||
|
> Tip: Many popular extensions utilize animations. This is an excellent way to show off your extension! We recommend short, focused animations that are easy to follow.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
If you have any requirements or dependencies, add a section describing those and how to install and configure them.
|
||||||
|
|
||||||
|
## Extension Settings
|
||||||
|
|
||||||
|
Include if your extension adds any VS Code settings through the `contributes.configuration` extension point.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
This extension contributes the following settings:
|
||||||
|
|
||||||
|
* `myExtension.enable`: Enable/disable this extension.
|
||||||
|
* `myExtension.thing`: Set to `blah` to do something.
|
||||||
|
|
||||||
|
## Known Issues
|
||||||
|
|
||||||
|
Calling out known issues can help limit users opening duplicate issues against your extension.
|
||||||
|
|
||||||
|
## Release Notes
|
||||||
|
|
||||||
|
Users appreciate release notes as you update your extension.
|
||||||
|
|
||||||
|
### 1.0.0
|
||||||
|
|
||||||
|
Initial release of ...
|
||||||
|
|
||||||
|
### 1.0.1
|
||||||
|
|
||||||
|
Fixed issue #.
|
||||||
|
|
||||||
|
### 1.1.0
|
||||||
|
|
||||||
|
Added features X, Y, and Z.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Following extension guidelines
|
||||||
|
|
||||||
|
Ensure that you've read through the extensions guidelines and follow the best practices for creating your extension.
|
||||||
|
|
||||||
|
* [Extension Guidelines](https://code.visualstudio.com/api/references/extension-guidelines)
|
||||||
|
|
||||||
|
## Working with Markdown
|
||||||
|
|
||||||
|
You can author your README using Visual Studio Code. Here are some useful editor keyboard shortcuts:
|
||||||
|
|
||||||
|
* Split the editor (`Cmd+\` on macOS or `Ctrl+\` on Windows and Linux).
|
||||||
|
* Toggle preview (`Shift+Cmd+V` on macOS or `Shift+Ctrl+V` on Windows and Linux).
|
||||||
|
* Press `Ctrl+Space` (Windows, Linux, macOS) to see a list of Markdown snippets.
|
||||||
|
|
||||||
|
## For more information
|
||||||
|
|
||||||
|
* [Visual Studio Code's Markdown Support](http://code.visualstudio.com/docs/languages/markdown)
|
||||||
|
* [Markdown Syntax Reference](https://help.github.com/articles/markdown-basics/)
|
||||||
|
|
||||||
|
**Enjoy!**
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 6.5 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 6.6 KiB |
|
|
@ -0,0 +1,56 @@
|
||||||
|
const esbuild = require("esbuild");
|
||||||
|
|
||||||
|
const production = process.argv.includes('--production');
|
||||||
|
const watch = process.argv.includes('--watch');
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @type {import('esbuild').Plugin}
|
||||||
|
*/
|
||||||
|
const esbuildProblemMatcherPlugin = {
|
||||||
|
name: 'esbuild-problem-matcher',
|
||||||
|
|
||||||
|
setup(build) {
|
||||||
|
build.onStart(() => {
|
||||||
|
console.log('[watch] build started');
|
||||||
|
});
|
||||||
|
build.onEnd((result) => {
|
||||||
|
result.errors.forEach(({ text, location }) => {
|
||||||
|
console.error(`✘ [ERROR] ${text}`);
|
||||||
|
console.error(` ${location.file}:${location.line}:${location.column}:`);
|
||||||
|
});
|
||||||
|
console.log('[watch] build finished');
|
||||||
|
});
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const ctx = await esbuild.context({
|
||||||
|
entryPoints: [
|
||||||
|
'src/extension.ts'
|
||||||
|
],
|
||||||
|
bundle: true,
|
||||||
|
format: 'cjs',
|
||||||
|
minify: production,
|
||||||
|
sourcemap: !production,
|
||||||
|
sourcesContent: false,
|
||||||
|
platform: 'node',
|
||||||
|
outfile: 'dist/extension.js',
|
||||||
|
external: ['vscode'],
|
||||||
|
logLevel: 'silent',
|
||||||
|
plugins: [
|
||||||
|
/* add to the end of plugins array */
|
||||||
|
esbuildProblemMatcherPlugin,
|
||||||
|
],
|
||||||
|
});
|
||||||
|
if (watch) {
|
||||||
|
await ctx.watch();
|
||||||
|
} else {
|
||||||
|
await ctx.rebuild();
|
||||||
|
await ctx.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
main().catch(e => {
|
||||||
|
console.error(e);
|
||||||
|
process.exit(1);
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
import typescriptEslint from "@typescript-eslint/eslint-plugin";
|
||||||
|
import tsParser from "@typescript-eslint/parser";
|
||||||
|
|
||||||
|
export default [{
|
||||||
|
files: ["**/*.ts"],
|
||||||
|
}, {
|
||||||
|
plugins: {
|
||||||
|
"@typescript-eslint": typescriptEslint,
|
||||||
|
},
|
||||||
|
|
||||||
|
languageOptions: {
|
||||||
|
parser: tsParser,
|
||||||
|
ecmaVersion: 2022,
|
||||||
|
sourceType: "module",
|
||||||
|
},
|
||||||
|
|
||||||
|
rules: {
|
||||||
|
"@typescript-eslint/naming-convention": ["warn", {
|
||||||
|
selector: "import",
|
||||||
|
format: ["camelCase", "PascalCase"],
|
||||||
|
}],
|
||||||
|
|
||||||
|
curly: "warn",
|
||||||
|
eqeqeq: "warn",
|
||||||
|
"no-throw-literal": "warn",
|
||||||
|
semi: "warn",
|
||||||
|
},
|
||||||
|
}];
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,67 @@
|
||||||
|
{
|
||||||
|
"name": "ra-aid",
|
||||||
|
"displayName": "RA.Aid",
|
||||||
|
"description": "Develop software autonomously.",
|
||||||
|
"version": "0.0.1",
|
||||||
|
"engines": {
|
||||||
|
"vscode": "^1.98.0"
|
||||||
|
},
|
||||||
|
"categories": [
|
||||||
|
"Other"
|
||||||
|
],
|
||||||
|
"activationEvents": [],
|
||||||
|
"main": "./dist/extension.js",
|
||||||
|
"contributes": {
|
||||||
|
"viewsContainers": {
|
||||||
|
"activitybar": [
|
||||||
|
{
|
||||||
|
"id": "ra-aid-view",
|
||||||
|
"title": "RA.Aid",
|
||||||
|
"icon": "assets/RA-white-transp.png"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"views": {
|
||||||
|
"ra-aid-view": [
|
||||||
|
{
|
||||||
|
"type": "webview",
|
||||||
|
"id": "ra-aid.view",
|
||||||
|
"name": "RA.Aid"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"commands": [
|
||||||
|
{
|
||||||
|
"command": "ra-aid.helloWorld",
|
||||||
|
"title": "Hello World"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"scripts": {
|
||||||
|
"vscode:prepublish": "npm run package",
|
||||||
|
"compile": "npm run check-types && npm run lint && node esbuild.js",
|
||||||
|
"watch": "npm-run-all -p watch:*",
|
||||||
|
"watch:esbuild": "node esbuild.js --watch",
|
||||||
|
"watch:tsc": "tsc --noEmit --watch --project tsconfig.json",
|
||||||
|
"package": "npm run check-types && npm run lint && node esbuild.js --production",
|
||||||
|
"compile-tests": "tsc -p . --outDir out",
|
||||||
|
"watch-tests": "tsc -p . -w --outDir out",
|
||||||
|
"pretest": "npm run compile-tests && npm run compile && npm run lint",
|
||||||
|
"check-types": "tsc --noEmit",
|
||||||
|
"lint": "eslint src",
|
||||||
|
"test": "vscode-test"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@types/vscode": "^1.98.0",
|
||||||
|
"@types/mocha": "^10.0.10",
|
||||||
|
"@types/node": "20.x",
|
||||||
|
"@typescript-eslint/eslint-plugin": "^8.25.0",
|
||||||
|
"@typescript-eslint/parser": "^8.25.0",
|
||||||
|
"eslint": "^9.21.0",
|
||||||
|
"esbuild": "^0.25.0",
|
||||||
|
"npm-run-all": "^4.1.5",
|
||||||
|
"typescript": "^5.7.3",
|
||||||
|
"@vscode/test-cli": "^0.0.10",
|
||||||
|
"@vscode/test-electron": "^2.4.1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
// The module 'vscode' contains the VS Code extensibility API
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* WebviewViewProvider implementation for the RA.Aid panel
|
||||||
|
*/
|
||||||
|
class RAWebviewViewProvider implements vscode.WebviewViewProvider {
|
||||||
|
constructor(private readonly _extensionUri: vscode.Uri) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when a view is first created to initialize the webview
|
||||||
|
*/
|
||||||
|
public resolveWebviewView(
|
||||||
|
webviewView: vscode.WebviewView,
|
||||||
|
context: vscode.WebviewViewResolveContext,
|
||||||
|
_token: vscode.CancellationToken
|
||||||
|
) {
|
||||||
|
// Set options for the webview
|
||||||
|
webviewView.webview.options = {
|
||||||
|
// Enable JavaScript in the webview
|
||||||
|
enableScripts: true,
|
||||||
|
// Restrict the webview to only load resources from the extension's directory
|
||||||
|
localResourceRoots: [this._extensionUri]
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set the HTML content of the webview
|
||||||
|
webviewView.webview.html = this._getHtmlForWebview(webviewView.webview);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates HTML content for the webview with proper security policies
|
||||||
|
*/
|
||||||
|
private _getHtmlForWebview(webview: vscode.Webview): string {
|
||||||
|
// Create a URI to the extension's assets directory
|
||||||
|
const logoUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'assets', 'RA.png'));
|
||||||
|
|
||||||
|
// Create a URI to the script file
|
||||||
|
// const scriptUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'dist', 'webview.js'));
|
||||||
|
|
||||||
|
// Use a nonce to whitelist scripts
|
||||||
|
const nonce = getNonce();
|
||||||
|
|
||||||
|
return `<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta http-equiv="Content-Security-Policy" content="default-src 'none'; img-src ${webview.cspSource} https:; style-src ${webview.cspSource} 'unsafe-inline'; script-src 'nonce-${nonce}';">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>RA.Aid</title>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
padding: 0;
|
||||||
|
color: var(--vscode-foreground);
|
||||||
|
font-size: var(--vscode-font-size);
|
||||||
|
font-weight: var(--vscode-font-weight);
|
||||||
|
font-family: var(--vscode-font-family);
|
||||||
|
background-color: var(--vscode-editor-background);
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 20px;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
.logo {
|
||||||
|
width: 100px;
|
||||||
|
height: 100px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
color: var(--vscode-editor-foreground);
|
||||||
|
font-size: 1.3em;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
p {
|
||||||
|
color: var(--vscode-foreground);
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<img src="${logoUri}" alt="RA.Aid Logo" class="logo">
|
||||||
|
<h1>RA.Aid</h1>
|
||||||
|
<p>Your research and development assistant.</p>
|
||||||
|
<p>More features coming soon!</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates a random nonce for CSP
|
||||||
|
*/
|
||||||
|
function getNonce() {
|
||||||
|
let text = '';
|
||||||
|
const possible = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';
|
||||||
|
for (let i = 0; i < 32; i++) {
|
||||||
|
text += possible.charAt(Math.floor(Math.random() * possible.length));
|
||||||
|
}
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This method is called when your extension is activated
|
||||||
|
export function activate(context: vscode.ExtensionContext) {
|
||||||
|
// Use the console to output diagnostic information (console.log) and errors (console.error)
|
||||||
|
console.log('Congratulations, your extension "ra-aid" is now active!');
|
||||||
|
|
||||||
|
// Register the WebviewViewProvider
|
||||||
|
const provider = new RAWebviewViewProvider(context.extensionUri);
|
||||||
|
const viewRegistration = vscode.window.registerWebviewViewProvider(
|
||||||
|
'ra-aid.view', // Must match the view id in package.json
|
||||||
|
provider
|
||||||
|
);
|
||||||
|
context.subscriptions.push(viewRegistration);
|
||||||
|
|
||||||
|
// The command has been defined in the package.json file
|
||||||
|
// Now provide the implementation of the command with registerCommand
|
||||||
|
// The commandId parameter must match the command field in package.json
|
||||||
|
const disposable = vscode.commands.registerCommand('ra-aid.helloWorld', () => {
|
||||||
|
// The code you place here will be executed every time your command is executed
|
||||||
|
// Display a message box to the user
|
||||||
|
vscode.window.showInformationMessage('Hello World from RA.Aid!');
|
||||||
|
});
|
||||||
|
|
||||||
|
context.subscriptions.push(disposable);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This method is called when your extension is deactivated
|
||||||
|
export function deactivate() {}
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import * as assert from 'assert';
|
||||||
|
|
||||||
|
// You can import and use all API from the 'vscode' module
|
||||||
|
// as well as import your extension to test it
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
// import * as myExtension from '../../extension';
|
||||||
|
|
||||||
|
suite('Extension Test Suite', () => {
|
||||||
|
vscode.window.showInformationMessage('Start all tests.');
|
||||||
|
|
||||||
|
test('Sample test', () => {
|
||||||
|
assert.strictEqual(-1, [1, 2, 3].indexOf(5));
|
||||||
|
assert.strictEqual(-1, [1, 2, 3].indexOf(0));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"module": "Node16",
|
||||||
|
"target": "ES2022",
|
||||||
|
"lib": [
|
||||||
|
"ES2022"
|
||||||
|
],
|
||||||
|
"sourceMap": true,
|
||||||
|
"rootDir": "src",
|
||||||
|
"strict": true, /* enable all strict type-checking options */
|
||||||
|
/* Additional Checks */
|
||||||
|
// "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */
|
||||||
|
// "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */
|
||||||
|
// "noUnusedParameters": true, /* Report errors on unused parameters. */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Welcome to your VS Code Extension
|
||||||
|
|
||||||
|
## What's in the folder
|
||||||
|
|
||||||
|
* This folder contains all of the files necessary for your extension.
|
||||||
|
* `package.json` - this is the manifest file in which you declare your extension and command.
|
||||||
|
* The sample plugin registers a command and defines its title and command name. With this information VS Code can show the command in the command palette. It doesn’t yet need to load the plugin.
|
||||||
|
* `src/extension.ts` - this is the main file where you will provide the implementation of your command.
|
||||||
|
* The file exports one function, `activate`, which is called the very first time your extension is activated (in this case by executing the command). Inside the `activate` function we call `registerCommand`.
|
||||||
|
* We pass the function containing the implementation of the command as the second parameter to `registerCommand`.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
* install the recommended extensions (amodio.tsl-problem-matcher, ms-vscode.extension-test-runner, and dbaeumer.vscode-eslint)
|
||||||
|
|
||||||
|
|
||||||
|
## Get up and running straight away
|
||||||
|
|
||||||
|
* Press `F5` to open a new window with your extension loaded.
|
||||||
|
* Run your command from the command palette by pressing (`Ctrl+Shift+P` or `Cmd+Shift+P` on Mac) and typing `Hello World`.
|
||||||
|
* Set breakpoints in your code inside `src/extension.ts` to debug your extension.
|
||||||
|
* Find output from your extension in the debug console.
|
||||||
|
|
||||||
|
## Make changes
|
||||||
|
|
||||||
|
* You can relaunch the extension from the debug toolbar after changing code in `src/extension.ts`.
|
||||||
|
* You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes.
|
||||||
|
|
||||||
|
|
||||||
|
## Explore the API
|
||||||
|
|
||||||
|
* You can open the full set of our API when you open the file `node_modules/@types/vscode/index.d.ts`.
|
||||||
|
|
||||||
|
## Run tests
|
||||||
|
|
||||||
|
* Install the [Extension Test Runner](https://marketplace.visualstudio.com/items?itemName=ms-vscode.extension-test-runner)
|
||||||
|
* Run the "watch" task via the **Tasks: Run Task** command. Make sure this is running, or tests might not be discovered.
|
||||||
|
* Open the Testing view from the activity bar and click the Run Test" button, or use the hotkey `Ctrl/Cmd + ; A`
|
||||||
|
* See the output of the test result in the Test Results view.
|
||||||
|
* Make changes to `src/test/extension.test.ts` or create new test files inside the `test` folder.
|
||||||
|
* The provided test runner will only consider files matching the name pattern `**.test.ts`.
|
||||||
|
* You can create folders inside the `test` folder to structure your tests any way you want.
|
||||||
|
|
||||||
|
## Go further
|
||||||
|
|
||||||
|
* Reduce the extension size and improve the startup time by [bundling your extension](https://code.visualstudio.com/api/working-with-extensions/bundling-extension).
|
||||||
|
* [Publish your extension](https://code.visualstudio.com/api/working-with-extensions/publishing-extension) on the VS Code extension marketplace.
|
||||||
|
* Automate builds by setting up [Continuous Integration](https://code.visualstudio.com/api/working-with-extensions/continuous-integration).
|
||||||
Loading…
Reference in New Issue