feat: add session and trajectory models to track application state and events

- Introduce a new `Session` model to store information about each program run, including command line arguments and environment details.
- Implement a `Trajectory` model to log significant events and errors during execution, enhancing debugging and monitoring capabilities.
- Update various repository classes to support session and trajectory management, allowing for better tracking of user interactions and system behavior.
- Modify existing functions to record relevant events in the trajectory, ensuring comprehensive logging of application activities.
- Enhance error handling by logging errors to the trajectory, providing insights into failures and system performance.

feat(vsc): add initial setup for VS Code extension "ra-aid" with essential files and configurations
chore(vsc): create tasks.json for managing build and watch tasks in VS Code
chore(vsc): add .vscodeignore to exclude unnecessary files from the extension package
docs(vsc): create CHANGELOG.md to document changes and updates for the extension
docs(vsc): add README.md with instructions and information about the extension
feat(vsc): include esbuild.js for building and bundling the extension
chore(vsc): add eslint.config.mjs for TypeScript linting configuration
chore(vsc): create package.json with dependencies and scripts for the extension
feat(vsc): implement extension logic in src/extension.ts with webview support
test(vsc): add initial test suite in extension.test.ts for extension functionality
chore(vsc): create tsconfig.json for TypeScript compiler options
docs(vsc): add vsc-extension-quickstart.md for guidance on extension development
This commit is contained in:
Ariel Frischer 2025-03-12 11:47:21 -07:00
commit 77a256317a
58 changed files with 9077 additions and 110 deletions

2
.gitignore vendored
View File

@ -14,3 +14,5 @@ __pycache__/
.envrc
appmap.log
*.swp
/vsc/node_modules
/vsc/dist

View File

@ -64,6 +64,9 @@ from ra_aid.database.repositories.trajectory_repository import (
TrajectoryRepositoryManager,
get_trajectory_repository,
)
from ra_aid.database.repositories.session_repository import (
SessionRepositoryManager, get_session_repository
)
from ra_aid.database.repositories.related_files_repository import (
RelatedFilesRepositoryManager,
)
@ -298,6 +301,11 @@ Examples:
action="store_true",
help="Display model thinking content extracted from think tags when supported by the model",
)
parser.add_argument(
"--show-cost",
action="store_true",
help="Display cost information as the agent works",
)
parser.add_argument(
"--reasoning-assistance",
action="store_true",
@ -538,18 +546,18 @@ def main():
env_discovery.discover()
env_data = env_discovery.format_markdown()
with (
KeyFactRepositoryManager(db) as key_fact_repo,
KeySnippetRepositoryManager(db) as key_snippet_repo,
HumanInputRepositoryManager(db) as human_input_repo,
ResearchNoteRepositoryManager(db) as research_note_repo,
RelatedFilesRepositoryManager() as related_files_repo,
TrajectoryRepositoryManager(db) as trajectory_repo,
WorkLogRepositoryManager() as work_log_repo,
ConfigRepositoryManager(config) as config_repo,
EnvInvManager(env_data) as env_inv,
):
with SessionRepositoryManager(db) as session_repo, \
KeyFactRepositoryManager(db) as key_fact_repo, \
KeySnippetRepositoryManager(db) as key_snippet_repo, \
HumanInputRepositoryManager(db) as human_input_repo, \
ResearchNoteRepositoryManager(db) as research_note_repo, \
RelatedFilesRepositoryManager() as related_files_repo, \
TrajectoryRepositoryManager(db) as trajectory_repo, \
WorkLogRepositoryManager() as work_log_repo, \
ConfigRepositoryManager(config) as config_repo, \
EnvInvManager(env_data) as env_inv:
# This initializes all repositories and makes them available via their respective get methods
logger.debug("Initialized SessionRepository")
logger.debug("Initialized KeyFactRepository")
logger.debug("Initialized KeySnippetRepository")
logger.debug("Initialized HumanInputRepository")
@ -560,6 +568,10 @@ def main():
logger.debug("Initialized ConfigRepository")
logger.debug("Initialized Environment Inventory")
# Create a new session for this program run
logger.debug("Initializing new session")
session_repo.create_session()
# Check dependencies before proceeding
check_dependencies()
@ -611,6 +623,7 @@ def main():
)
config_repo.set("web_research_enabled", web_research_enabled)
config_repo.set("show_thoughts", args.show_thoughts)
config_repo.set("show_cost", args.show_cost)
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
config_repo.set(
"disable_reasoning_assistance", args.no_reasoning_assistance
@ -636,11 +649,41 @@ def main():
)
if args.research_only:
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
error_message = "Chat mode cannot be used with --research-only"
trajectory_repo.create(
step_data={
"display_title": "Error",
"error_message": error_message,
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message,
)
except Exception as traj_error:
# Swallow exception to avoid recursion
logger.debug(f"Error recording trajectory: {traj_error}")
pass
print_error("Chat mode cannot be used with --research-only")
sys.exit(1)
print_stage_header("Chat Mode")
# Record stage transition in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"stage": "chat_mode",
"display_title": "Chat Mode",
},
record_type="stage_transition",
human_input_id=human_input_id
)
# Get project info
try:
project_info = get_project_info(".", file_limit=2000)
@ -690,12 +733,9 @@ def main():
config_repo.set("expert_model", args.expert_model)
config_repo.set("temperature", args.temperature)
config_repo.set("show_thoughts", args.show_thoughts)
config_repo.set(
"force_reasoning_assistance", args.reasoning_assistance
)
config_repo.set(
"disable_reasoning_assistance", args.no_reasoning_assistance
)
config_repo.set("show_cost", args.show_cost)
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
# Set modification tools based on use_aider flag
set_modification_tools(args.use_aider)
@ -737,6 +777,24 @@ def main():
# Validate message is provided
if not args.message:
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
error_message = "--message is required"
trajectory_repo.create(
step_data={
"display_title": "Error",
"error_message": error_message,
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message,
)
except Exception as traj_error:
# Swallow exception to avoid recursion
logger.debug(f"Error recording trajectory: {traj_error}")
pass
print_error("--message is required")
sys.exit(1)
@ -806,6 +864,18 @@ def main():
# Run research stage
print_stage_header("Research Stage")
# Record stage transition in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"stage": "research_stage",
"display_title": "Research Stage",
},
record_type="stage_transition",
human_input_id=human_input_id
)
# Initialize research model with potential overrides
research_provider = args.research_provider or args.provider
research_model_name = args.research_model or args.model

View File

@ -462,6 +462,38 @@ class CiaynAgent:
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
tool_name = self.extract_tool_name(code)
logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")
# Record error in trajectory
try:
# Import here to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": f"Tool execution failed for `{tool_name}`:\nError: {str(e)}",
"display_title": "Tool Error",
"code": code,
"tool_name": tool_name
},
record_type="tool_execution",
human_input_id=human_input_id,
is_error=True,
error_message=str(e),
error_type="ToolExecutionError",
tool_name=tool_name,
tool_parameters={"code": code}
)
except Exception as trajectory_error:
# Just log and continue if there's an error in trajectory recording
logger.error(f"Error recording trajectory for tool error display: {trajectory_error}")
print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
raise ToolExecutionError(
error_msg, base_message=msg, tool_name=tool_name
@ -495,6 +527,36 @@ class CiaynAgent:
if not fallback_response:
self.chat_history.append(err_msg)
logger.info(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
# Record error in trajectory
try:
# Import here to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
"display_title": "Fallback Failed",
"tool_name": e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=str(e),
error_type="FallbackFailedError",
tool_name=e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
)
except Exception as trajectory_error:
# Just log and continue if there's an error in trajectory recording
logger.error(f"Error recording trajectory for fallback failed warning: {trajectory_error}")
print_warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", title="Fallback Failed")
return ""
@ -595,6 +657,35 @@ class CiaynAgent:
matches = re.findall(pattern, response, re.DOTALL)
if len(matches) == 0:
logger.info("Failed to extract a valid tool call from the model's response.")
# Record error in trajectory
try:
# Import here to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": "Failed to extract a valid tool call from the model's response.",
"display_title": "Extraction Failed",
"code": code
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message="Failed to extract a valid tool call from the model's response.",
error_type="ExtractionError"
)
except Exception as trajectory_error:
# Just log and continue if there's an error in trajectory recording
logger.error(f"Error recording trajectory for extraction error display: {trajectory_error}")
print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
raise ToolExecutionError("Failed to extract tool call")
ma = matches[0][0].strip()
@ -647,6 +738,36 @@ class CiaynAgent:
warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
logger.info(warning_message)
# Record warning in trajectory
try:
# Import here to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db_connection
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db_connection())
human_input_repo = HumanInputRepository(get_db_connection())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"warning_message": warning_message,
"display_title": "Empty Response",
"attempt": empty_response_count,
"max_attempts": max_empty_responses
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=warning_message,
error_type="EmptyResponseWarning"
)
except Exception as trajectory_error:
# Just log and continue if there's an error in trajectory recording
logger.error(f"Error recording trajectory for empty response warning: {trajectory_error}")
print_warning(warning_message, title="Empty Response")
if empty_response_count >= max_empty_responses:
@ -658,6 +779,36 @@ class CiaynAgent:
error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
logger.error(error_message)
# Record error in trajectory
try:
# Import here to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db_connection
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db_connection())
human_input_repo = HumanInputRepository(get_db_connection())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Agent Crashed",
"crash_reason": crash_message,
"attempts": empty_response_count
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message,
error_type="AgentCrashError"
)
except Exception as trajectory_error:
# Just log and continue if there's an error in trajectory recording
logger.error(f"Error recording trajectory for agent crash: {trajectory_error}")
print_error(error_message)
yield self._create_error_chunk(crash_message)

View File

@ -46,6 +46,10 @@ from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
from ra_aid.database.repositories.human_input_repository import (
get_human_input_repository,
)
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
@ -284,9 +288,23 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()

View File

@ -22,6 +22,7 @@ from ra_aid import agent_utils
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
from ra_aid.tools.memory import log_work_event
@ -82,6 +83,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
if deleted_facts:
deleted_msg = "Successfully deleted facts:\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in deleted_facts])
result_parts.append(deleted_msg)
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"deleted_facts": deleted_facts,
"display_title": "Facts Deleted",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
)
@ -89,6 +106,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
if protected_facts:
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
result_parts.append(protected_msg)
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_facts": protected_facts,
"display_title": "Facts Protected",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
)
@ -120,10 +153,44 @@ def run_key_facts_gc_agent() -> None:
fact_count = len(facts)
except RuntimeError as e:
logger.error(f"Failed to access key fact repository: {str(e)}")
# Record GC error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error": str(e),
"display_title": "GC Error",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent",
is_error=True,
error_message=str(e),
error_type="Repository Error"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
return # Exit the function if we can't access the repository
# Display status panel with fact count included
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"fact_count": fact_count,
"display_title": "Garbage Collection",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
# Only run the agent if we actually have facts to clean
@ -185,6 +252,24 @@ def run_key_facts_gc_agent() -> None:
# Show info panel with updated count and protected facts count
protected_count = len(protected_facts)
if protected_count > 0:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": fact_count,
"updated_count": updated_count,
"protected_count": protected_count,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned key facts: {fact_count}{updated_count}\nProtected facts (associated with current request): {protected_count}",
@ -192,6 +277,24 @@ def run_key_facts_gc_agent() -> None:
)
)
else:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": fact_count,
"updated_count": updated_count,
"protected_count": 0,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned key facts: {fact_count}{updated_count}",
@ -199,6 +302,40 @@ def run_key_facts_gc_agent() -> None:
)
)
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_count": len(protected_facts),
"message": "All facts are protected",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"fact_count": 0,
"message": "No key facts to clean",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_facts_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))

View File

@ -18,6 +18,7 @@ from ra_aid import agent_utils
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.llm import initialize_llm
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
from ra_aid.tools.memory import log_work_event
@ -65,6 +66,23 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
success = get_key_snippet_repository().delete(snippet_id)
if success:
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"deleted_snippet_id": snippet_id,
"filepath": filepath,
"display_title": "Snippet Deleted",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
Markdown(success_msg), title="Snippet Deleted", border_style="green"
@ -86,6 +104,22 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
if protected_snippets:
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
result_parts.append(protected_msg)
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_snippets": protected_snippets,
"display_title": "Snippets Protected",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
)
@ -116,6 +150,21 @@ def run_key_snippets_gc_agent() -> None:
snippet_count = len(snippets)
# Display status panel with snippet count included
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"snippet_count": snippet_count,
"display_title": "Garbage Collection",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key snippets: {snippet_count}", title="🗑 Garbage Collection"))
# Only run the agent if we actually have snippets to clean
@ -185,6 +234,24 @@ def run_key_snippets_gc_agent() -> None:
# Show info panel with updated count and protected snippets count
protected_count = len(protected_snippets)
if protected_count > 0:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": snippet_count,
"updated_count": updated_count,
"protected_count": protected_count,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned key snippets: {snippet_count}{updated_count}\nProtected snippets (associated with current request): {protected_count}",
@ -192,6 +259,24 @@ def run_key_snippets_gc_agent() -> None:
)
)
else:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": snippet_count,
"updated_count": updated_count,
"protected_count": 0,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned key snippets: {snippet_count}{updated_count}",
@ -199,6 +284,40 @@ def run_key_snippets_gc_agent() -> None:
)
)
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_count": len(protected_snippets),
"message": "All snippets are protected",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"snippet_count": 0,
"message": "No key snippets to clean",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="key_snippets_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))

View File

@ -24,6 +24,8 @@ from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.env_inv_context import get_env_inv
from ra_aid.exceptions import AgentInterrupt
from ra_aid.llm import initialize_expert_llm
@ -156,6 +158,18 @@ def run_planning_agent(
# Display the planning stage header before any reasoning assistance
print_stage_header("Planning Stage")
# Record stage transition in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"stage": "planning_stage",
"display_title": "Planning Stage",
},
record_type="stage_transition",
human_input_id=human_input_id
)
# Initialize expert guidance section
expert_guidance = ""

View File

@ -22,6 +22,7 @@ from ra_aid.agent_utils import create_agent, run_agent_with_retry
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.llm import initialize_llm
from ra_aid.model_formatters.research_notes_formatter import format_research_note
from ra_aid.tools.memory import log_work_event
@ -84,6 +85,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
if deleted_notes:
deleted_msg = "Successfully deleted research notes:\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in deleted_notes])
result_parts.append(deleted_msg)
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"deleted_notes": deleted_notes,
"display_title": "Research Notes Deleted",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(Markdown(deleted_msg), title="Research Notes Deleted", border_style="green")
)
@ -91,6 +108,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
if protected_notes:
protected_msg = "Protected research notes (associated with current request):\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in protected_notes])
result_parts.append(protected_msg)
# Record GC operation in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_notes": protected_notes,
"display_title": "Research Notes Protected",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(Markdown(protected_msg), title="Research Notes Protected", border_style="blue")
)
@ -125,10 +158,44 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
note_count = len(notes)
except RuntimeError as e:
logger.error(f"Failed to access research note repository: {str(e)}")
# Record GC error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error": str(e),
"display_title": "GC Error",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent",
is_error=True,
error_message=str(e),
error_type="Repository Error"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
return # Exit the function if we can't access the repository
# Display status panel with note count included
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"note_count": note_count,
"display_title": "Garbage Collection",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Gathering my thoughts...\nCurrent number of research notes: {note_count}", title="🗑 Garbage Collection"))
# Only run the agent if we actually have notes to clean and we're over the threshold
@ -235,6 +302,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
# Show info panel with updated count and protected notes count
protected_count = len(protected_notes)
if protected_count > 0:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": note_count,
"updated_count": updated_count,
"protected_count": protected_count,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned research notes: {note_count}{updated_count}\nProtected notes (associated with current request): {protected_count}",
@ -242,6 +327,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
)
)
else:
# Record GC completion in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"original_count": note_count,
"updated_count": updated_count,
"protected_count": 0,
"display_title": "GC Complete",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(
Panel(
f"Cleaned research notes: {note_count}{updated_count}",
@ -249,6 +352,41 @@ Remember: Your goal is to maintain a concise, high-value collection of research
)
)
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"protected_count": len(protected_notes),
"message": "All research notes are protected",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"All {len(protected_notes)} research notes are associated with the current request and protected from deletion.", title="🗑 GC Info"))
else:
# Record GC info in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"note_count": note_count,
"threshold": threshold,
"message": "Below threshold - no cleanup needed",
"display_title": "GC Info",
},
record_type="gc_operation",
human_input_id=human_input_id,
tool_name="research_notes_gc_agent"
)
except Exception:
pass # Continue if trajectory recording fails
console.print(Panel(f"Research notes count ({note_count}) is below threshold ({threshold}). No cleanup needed.", title="🗑 GC Info"))

View File

@ -7,6 +7,7 @@ FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
DEFAULT_MODEL="claude-3-7-sonnet-20250219"
DEFAULT_SHOW_COST = False
VALID_PROVIDERS = [

View File

@ -1,6 +1,10 @@
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from typing import Optional
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
console = Console()

View File

@ -6,14 +6,18 @@ from rich.panel import Panel
from ra_aid.exceptions import ToolExecutionError
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.config import DEFAULT_SHOW_COST
# Import shared console instance
from .formatting import console
def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]:
"""Generate a subtitle with cost information if a callback is provided."""
if cost_cb:
"""Generate a subtitle with cost information if a callback is provided and show_cost is enabled."""
# Only show cost information if both cost_cb is provided AND show_cost is True
show_cost = get_config_repository().get("show_cost", DEFAULT_SHOW_COST)
if cost_cb and show_cost:
return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}"
return None

View File

@ -42,8 +42,8 @@ def initialize_database():
# to avoid circular imports
# Note: This import needs to be here, not at the top level
try:
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True)
logger.debug("Ensured database tables exist")
except Exception as e:
logger.error(f"Error creating tables: {str(e)}")
@ -99,6 +99,25 @@ class BaseModel(peewee.Model):
raise
class Session(BaseModel):
"""
Model representing a session stored in the database.
Sessions track information about each program run, providing a way to group
related records like human inputs, trajectories, and key facts.
Each session record captures details about when the program was started,
what command line arguments were used, and environment information.
"""
start_time = peewee.DateTimeField(default=datetime.datetime.now)
command_line = peewee.TextField(null=True)
program_version = peewee.TextField(null=True)
machine_info = peewee.TextField(null=True) # JSON-encoded machine information
class Meta:
table_name = "session"
class HumanInput(BaseModel):
"""
Model representing human input stored in the database.
@ -109,6 +128,7 @@ class HumanInput(BaseModel):
"""
content = peewee.TextField()
source = peewee.TextField() # 'cli', 'chat', or 'hil'
session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -124,6 +144,7 @@ class KeyFact(BaseModel):
"""
content = peewee.TextField()
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
session = peewee.ForeignKeyField(Session, backref='key_facts', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -143,6 +164,7 @@ class KeySnippet(BaseModel):
snippet = peewee.TextField()
description = peewee.TextField(null=True)
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -159,6 +181,7 @@ class ResearchNote(BaseModel):
"""
content = peewee.TextField()
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
session = peewee.ForeignKeyField(Session, backref='research_notes', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:
@ -182,17 +205,18 @@ class Trajectory(BaseModel):
- Error information (when a tool execution fails)
"""
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
tool_name = peewee.TextField()
tool_parameters = peewee.TextField() # JSON-encoded parameters
tool_result = peewee.TextField() # JSON-encoded result
step_data = peewee.TextField() # JSON-encoded UI rendering data
record_type = peewee.TextField() # Type of trajectory record
tool_name = peewee.TextField(null=True)
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
tool_result = peewee.TextField(null=True) # JSON-encoded result
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
record_type = peewee.TextField(null=True) # Type of trajectory record
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
error_message = peewee.TextField(null=True) # The error message
error_type = peewee.TextField(null=True) # The type/class of the error
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
session = peewee.ForeignKeyField(Session, backref='trajectories', null=True)
# created_at and updated_at are inherited from BaseModel
class Meta:

View File

@ -32,6 +32,7 @@ class ConfigRepository:
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
DEFAULT_TEST_CMD_TIMEOUT,
DEFAULT_SHOW_COST,
VALID_PROVIDERS,
)
@ -42,6 +43,7 @@ class ConfigRepository:
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
"retry_fallback_count": RETRY_FALLBACK_COUNT,
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
"show_cost": DEFAULT_SHOW_COST,
"valid_providers": VALID_PROVIDERS,
}

View File

@ -0,0 +1,243 @@
"""
Session repository implementation for database access.
This module provides a repository implementation for the Session model,
following the repository pattern for data access abstraction. It handles
operations for storing and retrieving application session information.
"""
from typing import Dict, List, Optional, Any
import contextvars
import datetime
import json
import logging
import sys
import peewee
from ra_aid.database.models import Session
from ra_aid.__version__ import __version__
from ra_aid.logging_config import get_logger
logger = get_logger(__name__)
# Create contextvar to hold the SessionRepository instance
session_repo_var = contextvars.ContextVar("session_repo", default=None)
class SessionRepositoryManager:
"""
Context manager for SessionRepository.
This class provides a context manager interface for SessionRepository,
using the contextvars approach for thread safety.
Example:
with DatabaseManager() as db:
with SessionRepositoryManager(db) as repo:
# Use the repository
session = repo.create_session()
current_session = repo.get_current_session()
"""
def __init__(self, db):
"""
Initialize the SessionRepositoryManager.
Args:
db: Database connection to use (required)
"""
self.db = db
def __enter__(self) -> 'SessionRepository':
"""
Initialize the SessionRepository and return it.
Returns:
SessionRepository: The initialized repository
"""
repo = SessionRepository(self.db)
session_repo_var.set(repo)
return repo
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[Exception],
exc_tb: Optional[object],
) -> None:
"""
Reset the repository when exiting the context.
Args:
exc_type: The exception type if an exception was raised
exc_val: The exception value if an exception was raised
exc_tb: The traceback if an exception was raised
"""
# Reset the contextvar to None
session_repo_var.set(None)
# Don't suppress exceptions
return False
def get_session_repository() -> 'SessionRepository':
"""
Get the current SessionRepository instance.
Returns:
SessionRepository: The current repository instance
Raises:
RuntimeError: If no repository has been initialized with SessionRepositoryManager
"""
repo = session_repo_var.get()
if repo is None:
raise RuntimeError(
"No SessionRepository available. "
"Make sure to initialize one with SessionRepositoryManager first."
)
return repo
class SessionRepository:
"""
Repository for handling Session records in the database.
This class provides methods for creating, retrieving, and managing Session records.
It abstracts away the database operations and provides a clean interface for working
with Session entities.
"""
def __init__(self, db):
"""
Initialize the SessionRepository.
Args:
db: Database connection to use (required)
"""
if db is None:
raise ValueError("Database connection is required for SessionRepository")
self.db = db
self.current_session = None
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
"""
Create a new session record in the database.
Args:
metadata: Optional dictionary of additional metadata to store with the session
Returns:
Session: The newly created session instance
Raises:
peewee.DatabaseError: If there's an error creating the record
"""
try:
# Get command line arguments
command_line = " ".join(sys.argv)
# Get program version
program_version = __version__
# JSON encode metadata if provided
machine_info = json.dumps(metadata) if metadata is not None else None
session = Session.create(
start_time=datetime.datetime.now(),
command_line=command_line,
program_version=program_version,
machine_info=machine_info
)
# Store the current session
self.current_session = session
logger.debug(f"Created new session with ID {session.id}")
return session
except peewee.DatabaseError as e:
logger.error(f"Failed to create session record: {str(e)}")
raise
def get_current_session(self) -> Optional[Session]:
"""
Get the current active session.
If no session has been created in this repository instance,
retrieves the most recent session from the database.
Returns:
Optional[Session]: The current session or None if no sessions exist
"""
if self.current_session is not None:
return self.current_session
try:
# Find the most recent session
session = Session.select().order_by(Session.created_at.desc()).first()
if session:
self.current_session = session
return session
except peewee.DatabaseError as e:
logger.error(f"Failed to get current session: {str(e)}")
return None
def get_current_session_id(self) -> Optional[int]:
"""
Get the ID of the current active session.
Returns:
Optional[int]: The ID of the current session or None if no session exists
"""
session = self.get_current_session()
return session.id if session else None
def get(self, session_id: int) -> Optional[Session]:
"""
Get a session by its ID.
Args:
session_id: The ID of the session to retrieve
Returns:
Optional[Session]: The session with the given ID or None if not found
"""
try:
return Session.get_or_none(Session.id == session_id)
except peewee.DatabaseError as e:
logger.error(f"Database error getting session {session_id}: {str(e)}")
return None
def get_all(self) -> List[Session]:
"""
Get all sessions from the database.
Returns:
List[Session]: List of all sessions
"""
try:
return list(Session.select().order_by(Session.created_at.desc()))
except peewee.DatabaseError as e:
logger.error(f"Failed to get all sessions: {str(e)}")
return []
def get_recent(self, limit: int = 10) -> List[Session]:
"""
Get the most recent sessions from the database.
Args:
limit: Maximum number of sessions to return (default: 10)
Returns:
List[Session]: List of the most recent sessions
"""
try:
return list(
Session.select()
.order_by(Session.created_at.desc())
.limit(limit)
)
except peewee.DatabaseError as e:
logger.error(f"Failed to get recent sessions: {str(e)}")
return []

View File

@ -132,8 +132,8 @@ class TrajectoryRepository:
def create(
self,
tool_name: str,
tool_parameters: Dict[str, Any],
tool_name: Optional[str] = None,
tool_parameters: Optional[Dict[str, Any]] = None,
tool_result: Optional[Dict[str, Any]] = None,
step_data: Optional[Dict[str, Any]] = None,
record_type: str = "tool_execution",
@ -149,8 +149,8 @@ class TrajectoryRepository:
Create a new trajectory record in the database.
Args:
tool_name: Name of the tool that was executed
tool_parameters: Parameters passed to the tool (will be JSON encoded)
tool_name: Optional name of the tool that was executed
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
tool_result: Result returned by the tool (will be JSON encoded)
step_data: UI rendering data (will be JSON encoded)
record_type: Type of trajectory record
@ -170,7 +170,7 @@ class TrajectoryRepository:
"""
try:
# Serialize JSON fields
tool_parameters_json = json.dumps(tool_parameters)
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
step_data_json = json.dumps(step_data) if step_data is not None else None
@ -185,7 +185,7 @@ class TrajectoryRepository:
# Create the trajectory record
trajectory = Trajectory.create(
human_input=human_input,
tool_name=tool_name,
tool_name=tool_name or "", # Use empty string if tool_name is None
tool_parameters=tool_parameters_json,
tool_result=tool_result_json,
step_data=step_data_json,
@ -197,7 +197,10 @@ class TrajectoryRepository:
error_type=error_type,
error_details=error_details
)
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
if tool_name:
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
else:
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
return trajectory
except peewee.DatabaseError as e:
logger.error(f"Failed to create trajectory record: {str(e)}")

View File

@ -154,6 +154,24 @@ class FallbackHandler:
logger.debug(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
)
# Import repository classes directly to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"message": f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
"display_title": "Fallback Notification",
},
record_type="info",
human_input_id=human_input_id
)
cpm(
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
title="Fallback Notification",
@ -163,6 +181,24 @@ class FallbackHandler:
if result_list:
return result_list
# Import repository classes directly to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"message": "All fallback models have failed.",
"display_title": "Fallback Failed",
},
record_type="error",
human_input_id=human_input_id
)
cpm("All fallback models have failed.", title="Fallback Failed")
current_failing_tool_name = self.current_failing_tool_name

View File

@ -234,6 +234,24 @@ def create_llm_client(
elif supports_temperature:
if temperature is None:
temperature = 0.7
# Import repository classes directly to avoid circular imports
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.database.connection import get_db
# Create repositories directly
trajectory_repo = TrajectoryRepository(get_db())
human_input_repo = HumanInputRepository(get_db())
human_input_id = human_input_repo.get_most_recent_id()
trajectory_repo.create(
step_data={
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
"display_title": "Information",
},
record_type="info",
human_input_id=human_input_id
)
cpm(
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
)

View File

@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
tool_name = pw.TextField()
tool_parameters = pw.TextField() # JSON-encoded parameters
tool_result = pw.TextField() # JSON-encoded result
step_data = pw.TextField() # JSON-encoded UI rendering data
record_type = pw.TextField() # Type of trajectory record
tool_name = pw.TextField(null=True) # JSON-encoded parameters
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
tool_result = pw.TextField(null=True) # JSON-encoded result
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
record_type = pw.TextField(null=True) # Type of trajectory record
cost = pw.FloatField(null=True) # Placeholder for cost tracking
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error

View File

@ -0,0 +1,79 @@
"""Peewee migrations -- 008_20250311_191232_add_session_model.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Create the session table for storing application session information."""
table_exists = False
# Check if the table already exists
try:
database.execute_sql("SELECT id FROM session LIMIT 1")
# If we reach here, the table exists
table_exists = True
except pw.OperationalError:
# Table doesn't exist, safe to create
pass
# Create the Session model - this registers it in migrator.orm as 'Session'
@migrator.create_model
class Session(pw.Model):
id = pw.AutoField()
created_at = pw.DateTimeField()
updated_at = pw.DateTimeField()
start_time = pw.DateTimeField()
command_line = pw.TextField(null=True)
program_version = pw.TextField(null=True)
machine_info = pw.TextField(null=True)
class Meta:
table_name = "session"
# FIX: Explicitly register the model under the lowercase table name key
# This ensures that later migrations can access it via either:
# - migrator.orm['Session'] (class name)
# - migrator.orm['session'] (table name)
if 'Session' in migrator.orm:
migrator.orm['session'] = migrator.orm['Session']
# Only return after model registration is complete
if table_exists:
return
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove the session table."""
migrator.remove_model('session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to HumanInput table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM human_input LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'human_input',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from HumanInput table."""
migrator.remove_fields('human_input', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to KeyFact table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM key_fact LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'key_fact',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from KeyFact table."""
migrator.remove_fields('key_fact', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to KeySnippet table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'key_snippet',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from KeySnippet table."""
migrator.remove_fields('key_snippet', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to ResearchNote table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM research_note LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'research_note',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from ResearchNote table."""
migrator.remove_fields('research_note', 'session')

View File

@ -0,0 +1,67 @@
"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add session foreign key to Trajectory table."""
# Get the Session model from migrator.orm
Session = migrator.orm['session']
# Check if the column already exists
try:
database.execute_sql("SELECT session_id FROM trajectory LIMIT 1")
# If we reach here, the column exists
return
except pw.OperationalError:
# Column doesn't exist, safe to add
pass
# Add the session_id foreign key column
migrator.add_fields(
'trajectory',
session=pw.ForeignKeyField(
Session,
null=True,
field='id',
on_delete='CASCADE'
)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Remove session foreign key from Trajectory table."""
migrator.remove_fields('trajectory', 'session')

View File

@ -17,6 +17,8 @@ __all__ = [
from ra_aid.file_listing import FileListerError, get_file_listing
from ra_aid.project_state import ProjectStateError, is_new_project
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
@dataclass
@ -130,6 +132,24 @@ def display_project_status(info: ProjectInfo) -> None:
{status} with **{file_count} file(s)**
"""
# Record project status in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"project_status": "new" if info.is_new else "existing",
"file_count": file_count,
"total_files": info.total_files,
"display_title": "Project Status",
},
record_type="info",
human_input_id=human_input_id
)
except Exception as e:
# Silently continue if trajectory recording fails
pass
# Create and display panel
console = Console()
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))

View File

@ -194,4 +194,5 @@ THE AGENT IS VERY FORGETFUL AND YOUR WRITING MUST INCLUDE REMARKS ABOUT HOW IT S
REMEMBER WE ARE INSTRUCTING THE AGENT **HOW TO DO RESEARCH ABOUT WHAT ALREADY EXISTS** AT THIS POINT USING THE TOOLS AVAILABLE. YOU ARE NOT TO DO THE ACTUAL RESEARCH YOURSELF. IF AN IMPLEMENTATION IS REQUESTED, THE AGENT SHOULD BE INSTRUCTED TO CALL request_task_implementation BUT ONLY AFTER EMITTING RESEARCH NOTES, KEY FACTS, AND KEY SNIPPETS AS RELEVANT.
IT IS IMPERATIVE THAT WE DO NOT START DIRECTLY IMPLEMENTING ANYTHING AT THIS POINT. WE ARE RESEARCHING, THEN CALLING request_implementation *AT MOST ONCE*.
IT IS IMPERATIVE THE AGENT EMITS KEY FACTS AND THOROUGH RESEARCH NOTES AT THIS POINT. THE RESEARCH NOTES CAN JUST BE THOUGHTS AT THIS POINT IF IT IS A NEW PROJECT.
THE AGENT MUST ALWAYS CALL emit_research_notes AT LEAST ONCE, ESPECIALLY IF IT CALLS ask_expert.
"""

View File

@ -15,11 +15,12 @@ from ra_aid.agent_context import (
reset_completion_flags,
)
from ra_aid.config import DEFAULT_MODEL
from ra_aid.console.formatting import print_error
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
from ra_aid.console.formatting import print_error, print_task_header
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.exceptions import AgentInterrupt
@ -27,8 +28,7 @@ from ra_aid.model_formatters import format_key_facts_dict
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
from ..console import print_task_header
from ..llm import initialize_llm
from ra_aid.llm import initialize_llm
from .human import ask_human
from .memory import get_related_files, get_work_log
@ -63,7 +63,23 @@ def request_research(query: str) -> ResearchResult:
# Check recursion depth
current_depth = get_depth()
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
print_error("Maximum research recursion depth reached")
error_message = "Maximum research recursion depth reached"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
try:
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
except RuntimeError as e:
@ -110,7 +126,23 @@ def request_research(query: str) -> ResearchResult:
except KeyboardInterrupt:
raise
except Exception as e:
print_error(f"Error during research: {str(e)}")
error_message = f"Error during research: {str(e)}"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
success = False
reason = f"error: {str(e)}"
finally:
@ -195,7 +227,23 @@ def request_web_research(query: str) -> ResearchResult:
except KeyboardInterrupt:
raise
except Exception as e:
print_error(f"Error during web research: {str(e)}")
error_message = f"Error during web research: {str(e)}"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
success = False
reason = f"error: {str(e)}"
finally:
@ -347,6 +395,19 @@ def request_task_implementation(task_spec: str) -> str:
try:
print_task_header(task_spec)
# Record task display in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"task": task_spec,
"display_title": "Task",
},
record_type="task_display",
human_input_id=human_input_id
)
# Run implementation agent
from ..agents.implementation_agent import run_task_implementation_agent
@ -372,7 +433,23 @@ def request_task_implementation(task_spec: str) -> str:
except KeyboardInterrupt:
raise
except Exception as e:
print_error(f"Error during task implementation: {str(e)}")
error_message = f"Error during task implementation: {str(e)}"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
success = False
reason = f"error: {str(e)}"
@ -503,7 +580,23 @@ def request_implementation(task_spec: str) -> str:
except KeyboardInterrupt:
raise
except Exception as e:
print_error(f"Error during planning: {str(e)}")
error_message = f"Error during planning: {str(e)}"
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_message,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_message
)
print_error(error_message)
success = False
reason = f"error: {str(e)}"

View File

@ -9,6 +9,9 @@ from rich.panel import Panel
logger = logging.getLogger(__name__)
from ..database.repositories.trajectory_repository import get_trajectory_repository
from ..database.repositories.human_input_repository import get_human_input_repository
from ..database.repositories.key_fact_repository import get_key_fact_repository
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
from ..database.repositories.related_files_repository import get_related_files_repository
@ -72,6 +75,23 @@ def emit_expert_context(context: str) -> str:
"""
expert_context["text"].append(context)
# Record expert context in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="emit_expert_context",
tool_parameters={"context_length": len(context)},
step_data={
"display_title": "Expert Context",
"context_length": len(context),
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
logger.error(f"Failed to record trajectory: {e}")
# Create and display status panel
panel_content = f"Added expert context ({len(context)} characters)"
console.print(Panel(panel_content, title="Expert Context", border_style="blue"))
@ -184,6 +204,23 @@ def ask_expert(question: str) -> str:
# Build display query (just question)
display_query = "# Question\n" + question
# Record expert query in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="ask_expert",
tool_parameters={"question": question},
step_data={
"display_title": "Expert Query",
"question": question,
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
logger.error(f"Failed to record trajectory: {e}")
# Show only question in panel
console.print(
Panel(Markdown(display_query), title="🤔 Expert Query", border_style="yellow")
@ -263,6 +300,23 @@ def ask_expert(question: str) -> str:
logger.error(f"Exception during content processing: {str(e)}")
raise
# Record expert response in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="ask_expert",
tool_parameters={"question": question},
step_data={
"display_title": "Expert Response",
"response_length": len(content),
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
logger.error(f"Failed to record trajectory: {e}")
# Format and display response
console.print(
Panel(Markdown(content), title="Expert Response", border_style="blue")

View File

@ -7,6 +7,8 @@ from rich.panel import Panel
from ra_aid.console import console
from ra_aid.console.formatting import print_error
from ra_aid.tools.memory import emit_related_files
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
def truncate_display_str(s: str, max_length: int = 30) -> str:
@ -54,6 +56,32 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
path = Path(filepath)
if not path.exists():
msg = f"File not found: {filepath}"
# Record error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": msg,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=msg,
tool_name="file_str_replace",
tool_parameters={
"filepath": filepath,
"old_str": old_str,
"new_str": new_str,
"replace_all": replace_all
}
)
except Exception:
# Silently handle trajectory recording failures (e.g., in test environments)
pass
print_error(msg)
return {"success": False, "message": msg}
@ -62,10 +90,62 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
if count == 0:
msg = f"String not found: {truncate_display_str(old_str)}"
# Record error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": msg,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=msg,
tool_name="file_str_replace",
tool_parameters={
"filepath": filepath,
"old_str": old_str,
"new_str": new_str,
"replace_all": replace_all
}
)
except Exception:
# Silently handle trajectory recording failures (e.g., in test environments)
pass
print_error(msg)
return {"success": False, "message": msg}
elif count > 1 and not replace_all:
msg = f"String appears {count} times - must be unique (use replace_all=True to replace all occurrences)"
# Record error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": msg,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=msg,
tool_name="file_str_replace",
tool_parameters={
"filepath": filepath,
"old_str": old_str,
"new_str": new_str,
"replace_all": replace_all
}
)
except Exception:
# Silently handle trajectory recording failures (e.g., in test environments)
pass
print_error(msg)
return {"success": False, "message": msg}
@ -93,7 +173,34 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
emit_related_files.invoke({"files": [filepath]})
except Exception as e:
# Don't let related files error affect main function success
print_error(f"Note: Could not add to related files: {str(e)}")
error_msg = f"Note: Could not add to related files: {str(e)}"
# Record error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": error_msg,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=error_msg,
tool_name="file_str_replace",
tool_parameters={
"filepath": filepath,
"old_str": old_str,
"new_str": new_str,
"replace_all": replace_all
}
)
except Exception:
# Silently handle trajectory recording failures (e.g., in test environments)
pass
print_error(error_msg)
return {
"success": True,
@ -102,5 +209,31 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
except Exception as e:
msg = f"Error: {str(e)}"
# Record error in trajectory
try:
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
step_data={
"error_message": msg,
"display_title": "Error",
},
record_type="error",
human_input_id=human_input_id,
is_error=True,
error_message=msg,
tool_name="file_str_replace",
tool_parameters={
"filepath": filepath,
"old_str": old_str,
"new_str": new_str,
"replace_all": replace_all
}
)
except Exception:
# Silently handle trajectory recording failures (e.g., in test environments)
pass
print_error(msg)
return {"success": False, "message": msg}

View File

@ -1,5 +1,6 @@
import fnmatch
from typing import List, Tuple
import logging
from typing import List, Tuple, Dict, Optional, Any
from fuzzywuzzy import process
from git import Repo, exc
@ -12,6 +13,49 @@ from ra_aid.file_listing import get_all_project_files, FileListerError
console = Console()
def record_trajectory(
tool_name: str,
tool_parameters: Dict,
step_data: Dict,
record_type: str = "tool_execution",
is_error: bool = False,
error_message: Optional[str] = None,
error_type: Optional[str] = None
) -> None:
"""
Helper function to record trajectory information, handling the case when repositories are not available.
Args:
tool_name: Name of the tool
tool_parameters: Parameters passed to the tool
step_data: UI rendering data
record_type: Type of trajectory record
is_error: Flag indicating if this record represents an error
error_message: The error message
error_type: The type/class of the error
"""
try:
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name=tool_name,
tool_parameters=tool_parameters,
step_data=step_data,
record_type=record_type,
human_input_id=human_input_id,
is_error=is_error,
error_message=error_message,
error_type=error_type
)
except (ImportError, RuntimeError):
# If either the repository modules can't be imported or no repository is available,
# just log and continue without recording trajectory
logging.debug("Skipping trajectory recording: repositories not available")
DEFAULT_EXCLUDE_PATTERNS = [
"*.pyc",
"__pycache__/*",
@ -57,7 +101,32 @@ def fuzzy_find_project_files(
"""
# Validate threshold
if not 0 <= threshold <= 100:
raise ValueError("Threshold must be between 0 and 100")
error_msg = "Threshold must be between 0 and 100"
# Record error in trajectory
record_trajectory(
tool_name="fuzzy_find_project_files",
tool_parameters={
"search_term": search_term,
"repo_path": repo_path,
"threshold": threshold,
"max_results": max_results,
"include_paths": include_paths,
"exclude_patterns": exclude_patterns,
"include_hidden": include_hidden
},
step_data={
"search_term": search_term,
"display_title": "Invalid Threshold Value",
"error_message": error_msg
},
record_type="tool_execution",
is_error=True,
error_message=error_msg,
error_type="ValueError"
)
raise ValueError(error_msg)
# Handle empty search term as special case
if not search_term:
@ -126,6 +195,27 @@ def fuzzy_find_project_files(
else:
info_sections.append("## Results\n*No matches found*")
# Record fuzzy find in trajectory
record_trajectory(
tool_name="fuzzy_find_project_files",
tool_parameters={
"search_term": search_term,
"repo_path": repo_path,
"threshold": threshold,
"max_results": max_results,
"include_paths": include_paths,
"exclude_patterns": exclude_patterns,
"include_hidden": include_hidden
},
step_data={
"search_term": search_term,
"display_title": "Fuzzy Find Results",
"total_files": len(all_files),
"matches_found": len(filtered_matches)
},
record_type="tool_execution"
)
# Display the panel
console.print(
Panel(
@ -138,5 +228,30 @@ def fuzzy_find_project_files(
return filtered_matches
except FileListerError as e:
console.print(f"[bold red]Error listing files: {e}[/bold red]")
error_msg = f"Error listing files: {e}"
# Record error in trajectory
record_trajectory(
tool_name="fuzzy_find_project_files",
tool_parameters={
"search_term": search_term,
"repo_path": repo_path,
"threshold": threshold,
"max_results": max_results,
"include_paths": include_paths,
"exclude_patterns": exclude_patterns,
"include_hidden": include_hidden
},
step_data={
"search_term": search_term,
"display_title": "Fuzzy Find Error",
"error_message": error_msg
},
record_type="tool_execution",
is_error=True,
error_message=error_msg,
error_type=type(e).__name__
)
console.print(f"[bold red]{error_msg}[/bold red]")
return []

View File

@ -17,6 +17,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
from ra_aid.model_formatters import key_snippets_formatter
from ra_aid.logging_config import get_logger
@ -69,6 +70,22 @@ def emit_research_notes(notes: str) -> str:
from ra_aid.model_formatters.research_notes_formatter import format_research_note
formatted_note = format_research_note(note_id, notes)
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_research_notes",
tool_parameters={"notes": notes},
step_data={
"note_id": note_id,
"display_title": "Research Notes",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display formatted note
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
@ -123,6 +140,23 @@ def emit_key_facts(facts: List[str]) -> str:
console.print(f"Error storing fact: {str(e)}", style="red")
continue
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_key_facts",
tool_parameters={"facts": [fact]},
step_data={
"fact_id": fact_id,
"fact": fact,
"display_title": f"Key Fact #{fact_id}",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display panel with ID
console.print(
Panel(
@ -214,6 +248,32 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
if snippet_info["description"]:
display_text.extend(["", "**Description**:", snippet_info["description"]])
# Record to trajectory before displaying panel
try:
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_key_snippet",
tool_parameters={
"snippet_info": {
"filepath": snippet_info["filepath"],
"line_number": snippet_info["line_number"],
"description": snippet_info["description"],
# Omit the full snippet content to avoid duplicating large text in the database
"snippet_length": len(snippet_info["snippet"])
}
},
step_data={
"snippet_id": snippet_id,
"filepath": snippet_info["filepath"],
"line_number": snippet_info["line_number"],
"display_title": f"Key Snippet #{snippet_id}",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
# Display panel
console.print(
Panel(
@ -248,6 +308,25 @@ def one_shot_completed(message: str) -> str:
message: Completion message to display
"""
mark_task_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="one_shot_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Task Completed",
},
record_type="task_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@ -261,6 +340,25 @@ def task_completed(message: str) -> str:
message: Message explaining how/why the task is complete
"""
mark_task_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="task_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Task Completed",
},
record_type="task_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Task Completed"))
log_work_event(f"Task completed:\n\n{message}")
return "Completion noted."
@ -275,6 +373,25 @@ def plan_implementation_completed(message: str) -> str:
"""
mark_should_exit(propagation_depth=1)
mark_plan_completed(message)
# Record to trajectory before displaying panel
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="plan_implementation_completed",
tool_parameters={"message": message},
step_data={
"completion_message": message,
"display_title": "Plan Executed",
},
record_type="plan_completion",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
log_work_event(f"Completed implementation:\n\n{message}")
return "Plan completion noted."
@ -361,10 +478,29 @@ def emit_related_files(files: List[str]) -> str:
results.append(f"File ID #{file_id}: {file}")
# Rich output - single consolidated panel for added files
# Record to trajectory before displaying panel for added files
if added_files:
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
md_content = f"**Files Noted:**\n{files_added_md}"
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_related_files",
tool_parameters={"files": files},
step_data={
"added_files": [file for _, file in added_files],
"added_file_ids": [file_id for file_id, _ in added_files],
"display_title": "Related Files Noted",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(
Panel(
Markdown(md_content),
@ -373,10 +509,28 @@ def emit_related_files(files: List[str]) -> str:
)
)
# Display skipped binary files
# Record to trajectory before displaying panel for binary files
if binary_files:
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
human_input_id = None
try:
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo = get_trajectory_repository()
trajectory_repo.create(
tool_name="emit_related_files",
tool_parameters={"files": files},
step_data={
"binary_files": binary_files,
"display_title": "Binary Files Not Added",
},
record_type="memory_operation",
human_input_id=human_input_id
)
except RuntimeError as e:
logger.warning(f"Failed to record trajectory: {str(e)}")
console.print(
Panel(
Markdown(md_content),

View File

@ -1,7 +1,7 @@
import logging
import os.path
import time
from typing import Dict
from typing import Dict, Optional
from langchain_core.tools import tool
from rich.console import Console
@ -16,6 +16,49 @@ console = Console()
CHUNK_SIZE = 8192
def record_trajectory(
tool_name: str,
tool_parameters: Dict,
step_data: Dict,
record_type: str = "tool_execution",
is_error: bool = False,
error_message: Optional[str] = None,
error_type: Optional[str] = None
) -> None:
"""
Helper function to record trajectory information, handling the case when repositories are not available.
Args:
tool_name: Name of the tool
tool_parameters: Parameters passed to the tool
step_data: UI rendering data
record_type: Type of trajectory record
is_error: Flag indicating if this record represents an error
error_message: The error message
error_type: The type/class of the error
"""
try:
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name=tool_name,
tool_parameters=tool_parameters,
step_data=step_data,
record_type=record_type,
human_input_id=human_input_id,
is_error=is_error,
error_message=error_message,
error_type=error_type
)
except (ImportError, RuntimeError):
# If either the repository modules can't be imported or no repository is available,
# just log and continue without recording trajectory
logging.debug("Skipping trajectory recording: repositories not available")
@tool
def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
"""Read and return the contents of a text file.
@ -29,10 +72,43 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
start_time = time.time()
try:
if not os.path.exists(filepath):
# Record error in trajectory
record_trajectory(
tool_name="read_file_tool",
tool_parameters={
"filepath": filepath,
"encoding": encoding
},
step_data={
"filepath": filepath,
"display_title": "File Not Found",
"error_message": f"File not found: {filepath}"
},
is_error=True,
error_message=f"File not found: {filepath}",
error_type="FileNotFoundError"
)
raise FileNotFoundError(f"File not found: {filepath}")
# Check if the file is binary
if is_binary_file(filepath):
# Record binary file error in trajectory
record_trajectory(
tool_name="read_file_tool",
tool_parameters={
"filepath": filepath,
"encoding": encoding
},
step_data={
"filepath": filepath,
"display_title": "Binary File Detected",
"error_message": f"Cannot read binary file: {filepath}"
},
is_error=True,
error_message="Cannot read binary file",
error_type="BinaryFileError"
)
console.print(
Panel(
f"Cannot read binary file: {filepath}",
@ -67,6 +143,22 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s")
logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines")
# Record successful file read in trajectory
record_trajectory(
tool_name="read_file_tool",
tool_parameters={
"filepath": filepath,
"encoding": encoding
},
step_data={
"filepath": filepath,
"display_title": "File Read",
"line_count": line_count,
"total_bytes": total_bytes,
"elapsed_time": elapsed
}
)
console.print(
Panel(
f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s",
@ -80,6 +172,25 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
return {"content": truncated}
except Exception:
except Exception as e:
elapsed = time.time() - start_time
# Record exception in trajectory (if it's not already a handled FileNotFoundError)
if not isinstance(e, FileNotFoundError):
record_trajectory(
tool_name="read_file_tool",
tool_parameters={
"filepath": filepath,
"encoding": encoding
},
step_data={
"filepath": filepath,
"display_title": "File Read Error",
"error_message": str(e)
},
is_error=True,
error_message=str(e),
error_type=type(e).__name__
)
raise

View File

@ -2,6 +2,9 @@ from langchain_core.tools import tool
from rich.console import Console
from rich.panel import Panel
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
console = Console()
@ -10,6 +13,24 @@ def existing_project_detected() -> dict:
"""
When to call: Once you have confirmed that the current working directory contains project files.
"""
try:
# Record detection in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="existing_project_detected",
tool_parameters={},
step_data={
"detection_type": "existing_project",
"display_title": "Existing Project Detected",
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
# Continue even if trajectory recording fails
console.print(f"Warning: Could not record trajectory: {str(e)}")
console.print(Panel("📁 Existing Project Detected", style="bright_blue", padding=0))
return {
"hint": (
@ -30,6 +51,24 @@ def monorepo_detected() -> dict:
"""
When to call: After identifying that multiple packages or modules exist within a single repository.
"""
try:
# Record detection in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="monorepo_detected",
tool_parameters={},
step_data={
"detection_type": "monorepo",
"display_title": "Monorepo Detected",
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
# Continue even if trajectory recording fails
console.print(f"Warning: Could not record trajectory: {str(e)}")
console.print(Panel("📦 Monorepo Detected", style="bright_blue", padding=0))
return {
"hint": (
@ -53,6 +92,24 @@ def ui_detected() -> dict:
"""
When to call: After detecting that the project contains a user interface layer or front-end component.
"""
try:
# Record detection in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="ui_detected",
tool_parameters={},
step_data={
"detection_type": "ui",
"display_title": "UI Detected",
},
record_type="tool_execution",
human_input_id=human_input_id
)
except Exception as e:
# Continue even if trajectory recording fails
console.print(f"Warning: Could not record trajectory: {str(e)}")
console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0))
return {
"hint": (

View File

@ -5,6 +5,8 @@ from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
@ -158,6 +160,30 @@ def ripgrep_search(
info_sections.append("\n".join(params))
# Execute command
# Record ripgrep search in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="ripgrep_search",
tool_parameters={
"pattern": pattern,
"before_context_lines": before_context_lines,
"after_context_lines": after_context_lines,
"file_type": file_type,
"case_sensitive": case_sensitive,
"include_hidden": include_hidden,
"follow_links": follow_links,
"exclude_dirs": exclude_dirs,
"fixed_string": fixed_string
},
step_data={
"search_pattern": pattern,
"display_title": "Ripgrep Search",
},
record_type="tool_execution",
human_input_id=human_input_id
)
console.print(
Panel(
Markdown(f"Searching for: **{pattern}**"),
@ -179,5 +205,34 @@ def ripgrep_search(
except Exception as e:
error_msg = str(e)
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="ripgrep_search",
tool_parameters={
"pattern": pattern,
"before_context_lines": before_context_lines,
"after_context_lines": after_context_lines,
"file_type": file_type,
"case_sensitive": case_sensitive,
"include_hidden": include_hidden,
"follow_links": follow_links,
"exclude_dirs": exclude_dirs,
"fixed_string": fixed_string
},
step_data={
"search_pattern": pattern,
"display_title": "Ripgrep Search Error",
"error_message": error_msg
},
record_type="tool_execution",
human_input_id=human_input_id,
is_error=True,
error_message=error_msg,
error_type=type(e).__name__
)
console.print(Panel(error_msg, title="❌ Error", border_style="red"))
return {"output": error_msg, "return_code": 1, "success": False}

View File

@ -10,6 +10,8 @@ from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import log_work_event
from ra_aid.database.repositories.config_repository import get_config_repository
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
console = Console()
@ -54,6 +56,20 @@ def run_shell_command(
console.print(" " + get_cowboy_message())
console.print("")
# Record tool execution in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="run_shell_command",
tool_parameters={"command": command, "timeout": timeout},
step_data={
"command": command,
"display_title": "Shell Command",
},
record_type="tool_execution",
human_input_id=human_input_id
)
# Show just the command in a simple panel
console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow"))
@ -96,5 +112,23 @@ def run_shell_command(
return result
except Exception as e:
print()
# Record error in trajectory
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="run_shell_command",
tool_parameters={"command": command, "timeout": timeout},
step_data={
"command": command,
"error": str(e),
"display_title": "Shell Error",
},
record_type="tool_execution",
is_error=True,
error_message=str(e),
error_type=type(e).__name__,
human_input_id=human_input_id
)
console.print(Panel(str(e), title="❌ Error", border_style="red"))
return {"output": str(e), "return_code": 1, "success": False}

View File

@ -7,6 +7,9 @@ from rich.markdown import Markdown
from rich.panel import Panel
from tavily import TavilyClient
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
console = Console()
@ -21,9 +24,44 @@ def web_search_tavily(query: str) -> Dict:
Returns:
Dict containing search results from Tavily
"""
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
# Record trajectory before displaying panel
trajectory_repo = get_trajectory_repository()
human_input_id = get_human_input_repository().get_most_recent_id()
trajectory_repo.create(
tool_name="web_search_tavily",
tool_parameters={"query": query},
step_data={
"query": query,
"display_title": "Web Search",
},
record_type="tool_execution",
human_input_id=human_input_id
)
# Display search query panel
console.print(
Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue")
)
search_result = client.search(query=query)
return search_result
try:
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
search_result = client.search(query=query)
return search_result
except Exception as e:
# Record error in trajectory
trajectory_repo.create(
tool_name="web_search_tavily",
tool_parameters={"query": query},
step_data={
"query": query,
"display_title": "Web Search Error",
"error": str(e)
},
record_type="tool_execution",
human_input_id=human_input_id,
is_error=True,
error_message=str(e),
error_type=type(e).__name__
)
# Re-raise the exception to maintain original behavior
raise

View File

@ -7,7 +7,7 @@ ensuring consistent test environments and proper isolation.
import os
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
@ -26,6 +26,39 @@ def mock_config_repository():
yield repo
@pytest.fixture()
def mock_trajectory_repository():
"""Mock the TrajectoryRepository to avoid database operations during tests."""
with patch('ra_aid.database.repositories.trajectory_repository.TrajectoryRepository') as mock:
# Setup a mock repository
mock_repo = MagicMock()
mock_repo.create.return_value = MagicMock(id=1)
mock.return_value = mock_repo
yield mock_repo
@pytest.fixture()
def mock_human_input_repository():
"""Mock the HumanInputRepository to avoid database operations during tests."""
with patch('ra_aid.database.repositories.human_input_repository.HumanInputRepository') as mock:
# Setup a mock repository
mock_repo = MagicMock()
mock_repo.get_most_recent_id.return_value = 1
mock_repo.create.return_value = MagicMock(id=1)
mock.return_value = mock_repo
yield mock_repo
@pytest.fixture()
def mock_repository_access(mock_trajectory_repository, mock_human_input_repository):
"""Mock all repository accessor functions."""
with patch('ra_aid.database.repositories.trajectory_repository.get_trajectory_repository',
return_value=mock_trajectory_repository):
with patch('ra_aid.database.repositories.human_input_repository.get_human_input_repository',
return_value=mock_human_input_repository):
yield
@pytest.fixture(autouse=True)
def isolated_db_environment(tmp_path, monkeypatch):
"""

View File

@ -21,7 +21,11 @@ from ra_aid.anthropic_token_limiter import (
state_modifier,
)
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
from ra_aid.database.repositories.config_repository import (
ConfigRepositoryManager,
get_config_repository,
config_repo_var,
)
@pytest.fixture
@ -34,7 +38,9 @@ def mock_model():
@pytest.fixture
def mock_config_repository():
"""Mock the ConfigRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
with patch(
"ra_aid.database.repositories.config_repository.config_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
@ -44,6 +50,7 @@ def mock_config_repository():
# Setup get method to return config values
def get_config(key, default=None):
return config.get(key, default)
mock_repo.get.side_effect = get_config
# Setup get_all method to return all config values
@ -52,11 +59,13 @@ def mock_config_repository():
# Setup set method to update config values
def set_config(key, value):
config[key] = value
mock_repo.set.side_effect = set_config
# Setup update method to update multiple config values
def update_config(update_dict):
config.update(update_dict)
mock_repo.update.side_effect = update_config
# Make the mock context var return our mock repo
@ -65,15 +74,55 @@ def mock_config_repository():
yield mock_repo
# These tests have been moved to test_anthropic_token_limiter.py
@pytest.fixture(autouse=True)
def mock_trajectory_repository():
"""Mock the TrajectoryRepository to avoid database operations during tests"""
with patch(
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup create method to return a mock trajectory
def mock_create(**kwargs):
mock_trajectory = MagicMock()
mock_trajectory.id = 1
return mock_trajectory
mock_repo.create.side_effect = mock_create
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture(autouse=True)
def mock_human_input_repository():
"""Mock the HumanInputRepository to avoid database operations during tests"""
with patch(
"ra_aid.database.repositories.human_input_repository.human_input_repo_var"
) as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup get_most_recent_id method to return a dummy ID
mock_repo.get_most_recent_id.return_value = 1
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
def test_create_agent_anthropic(mock_model, mock_config_repository):
"""Test create_agent with Anthropic Claude model."""
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
):
mock_react.return_value = "react_agent"
agent = create_agent(mock_model, [])
@ -81,7 +130,7 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
mock_react.assert_called_once_with(
mock_model,
[],
interrupt_after=['tools'],
interrupt_after=["tools"],
version="v2",
state_modifier=mock_react.call_args[1]["state_modifier"],
name="React",
@ -173,13 +222,17 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
)
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
def test_create_agent_anthropic_token_limiting_enabled(
mock_model, mock_config_repository
):
"""Test create_agent sets up token limiting for Claude models when enabled."""
mock_config_repository.update({
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": True,
})
mock_config_repository.update(
{
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": True,
}
)
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@ -196,13 +249,17 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
assert callable(args[1]["state_modifier"])
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
def test_create_agent_anthropic_token_limiting_disabled(
mock_model, mock_config_repository
):
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
mock_config_repository.update({
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": False,
})
mock_config_repository.update(
{
"provider": "anthropic",
"model": "claude-2",
"limit_tokens": False,
}
)
with (
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
@ -214,7 +271,9 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
agent = create_agent(mock_model, [])
assert agent == "react_agent"
mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2", name="React")
mock_react.assert_called_once_with(
mock_model, [], interrupt_after=["tools"], version="v2", name="React"
)
# These tests have been moved to test_anthropic_token_limiter.py
@ -482,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos
assert "Agent has crashed: Test crash message" in result
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_badrequest_error(
monkeypatch, mock_config_repository
):
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
from ra_aid.agent_context import agent_context, is_crashed
from ra_aid.agent_utils import run_agent_with_retry
@ -540,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_
assert is_crashed()
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
def test_run_agent_with_retry_handles_api_badrequest_error(
monkeypatch, mock_config_repository
):
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
# Import APIError from anthropic module and patch it on the agent_utils module
@ -613,5 +676,7 @@ def test_handle_api_error_resource_exhausted():
from ra_aid.agent_utils import _handle_api_error
# ResourceExhausted exception should be handled without raising
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
resource_exhausted_error = ResourceExhausted(
"429 Resource has been exhausted (e.g. check quota)."
)
_handle_api_error(resource_exhausted_error, 0, 5, 1)

View File

@ -113,6 +113,40 @@ def mock_work_log_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_trajectory_repository():
"""Mock the TrajectoryRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup create method to return a mock trajectory
def mock_create(**kwargs):
mock_trajectory = MagicMock()
mock_trajectory.id = 1
return mock_trajectory
mock_repo.create.side_effect = mock_create
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture(autouse=True)
def mock_human_input_repository():
"""Mock the HumanInputRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup get_most_recent_id method to return a dummy ID
mock_repo.get_most_recent_id.return_value = 1
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture
def mock_functions():
"""Mock functions used in agent.py"""
@ -126,7 +160,9 @@ def mock_functions():
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion, \
patch('ra_aid.tools.agent.get_trajectory_repository') as mock_get_trajectory_repo, \
patch('ra_aid.tools.agent.get_human_input_repository') as mock_get_human_input_repo:
# Setup mock return values
mock_fact_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
@ -138,6 +174,15 @@ def mock_functions():
mock_get_work_log.return_value = "Test work log"
mock_get_completion.return_value = "Task completed"
# Setup mock for trajectory repository
mock_trajectory_repo = MagicMock()
mock_get_trajectory_repo.return_value = mock_trajectory_repo
# Setup mock for human input repository
mock_human_input_repo = MagicMock()
mock_human_input_repo.get_most_recent_id.return_value = 1
mock_get_human_input_repo.return_value = mock_human_input_repo
# Return all mocks as a dictionary
yield {
'get_key_fact_repository': mock_get_fact_repo,
@ -148,7 +193,9 @@ def mock_functions():
'get_related_files': mock_get_files,
'get_work_log': mock_get_work_log,
'reset_completion_flags': mock_reset,
'get_completion_message': mock_get_completion
'get_completion_message': mock_get_completion,
'get_trajectory_repository': mock_get_trajectory_repo,
'get_human_input_repository': mock_get_human_input_repo
}

View File

@ -52,6 +52,40 @@ def mock_config_repository():
yield mock_repo
@pytest.fixture(autouse=True)
def mock_trajectory_repository():
"""Mock the TrajectoryRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup create method to return a mock trajectory
def mock_create(**kwargs):
mock_trajectory = MagicMock()
mock_trajectory.id = 1
return mock_trajectory
mock_repo.create.side_effect = mock_create
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
@pytest.fixture(autouse=True)
def mock_human_input_repository():
"""Mock the HumanInputRepository to avoid database operations during tests"""
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
# Setup a mock repository
mock_repo = MagicMock()
# Setup get_most_recent_id method to return a dummy ID
mock_repo.get_most_recent_id.return_value = 1
# Make the mock context var return our mock repo
mock_repo_var.get.return_value = mock_repo
yield mock_repo
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
"""Test shell command execution in cowboy mode (no approval)"""

5
vsc/.vscode-test.mjs Normal file
View File

@ -0,0 +1,5 @@
import { defineConfig } from '@vscode/test-cli';
export default defineConfig({
files: 'out/test/**/*.test.js',
});

5
vsc/.vscode/extensions.json vendored Normal file
View File

@ -0,0 +1,5 @@
{
// See http://go.microsoft.com/fwlink/?LinkId=827846
// for the documentation about the extensions.json format
"recommendations": ["dbaeumer.vscode-eslint", "connor4312.esbuild-problem-matchers", "ms-vscode.extension-test-runner"]
}

21
vsc/.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,21 @@
// A launch configuration that compiles the extension and then opens it inside a new window
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
{
"version": "0.2.0",
"configurations": [
{
"name": "Run Extension",
"type": "extensionHost",
"request": "launch",
"args": [
"--extensionDevelopmentPath=${workspaceFolder}"
],
"outFiles": [
"${workspaceFolder}/dist/**/*.js"
],
"preLaunchTask": "${defaultBuildTask}"
}
]
}

13
vsc/.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,13 @@
// Place your settings in this file to overwrite default and user settings.
{
"files.exclude": {
"out": false, // set this to true to hide the "out" folder with the compiled JS files
"dist": false // set this to true to hide the "dist" folder with the compiled JS files
},
"search.exclude": {
"out": true, // set this to false to include "out" folder in search results
"dist": true // set this to false to include "dist" folder in search results
},
// Turn off tsc task auto detection since we have the necessary tasks as npm scripts
"typescript.tsc.autoDetect": "off"
}

64
vsc/.vscode/tasks.json vendored Normal file
View File

@ -0,0 +1,64 @@
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
{
"version": "2.0.0",
"tasks": [
{
"label": "watch",
"dependsOn": [
"npm: watch:tsc",
"npm: watch:esbuild"
],
"presentation": {
"reveal": "never"
},
"group": {
"kind": "build",
"isDefault": true
}
},
{
"type": "npm",
"script": "watch:esbuild",
"group": "build",
"problemMatcher": "$esbuild-watch",
"isBackground": true,
"label": "npm: watch:esbuild",
"presentation": {
"group": "watch",
"reveal": "never"
}
},
{
"type": "npm",
"script": "watch:tsc",
"group": "build",
"problemMatcher": "$tsc-watch",
"isBackground": true,
"label": "npm: watch:tsc",
"presentation": {
"group": "watch",
"reveal": "never"
}
},
{
"type": "npm",
"script": "watch-tests",
"problemMatcher": "$tsc-watch",
"isBackground": true,
"presentation": {
"reveal": "never",
"group": "watchers"
},
"group": "build"
},
{
"label": "tasks: watch-tests",
"dependsOn": [
"npm: watch",
"npm: watch-tests"
],
"problemMatcher": []
}
]
}

14
vsc/.vscodeignore Normal file
View File

@ -0,0 +1,14 @@
.vscode/**
.vscode-test/**
out/**
node_modules/**
src/**
.gitignore
.yarnrc
esbuild.js
vsc-extension-quickstart.md
**/tsconfig.json
**/eslint.config.mjs
**/*.map
**/*.ts
**/.vscode-test.*

9
vsc/CHANGELOG.md Normal file
View File

@ -0,0 +1,9 @@
# Change Log
All notable changes to the "ra-aid" extension will be documented in this file.
Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file.
## [Unreleased]
- Initial release

71
vsc/README.md Normal file
View File

@ -0,0 +1,71 @@
# ra-aid README
This is the README for your extension "ra-aid". After writing up a brief description, we recommend including the following sections.
## Features
Describe specific features of your extension including screenshots of your extension in action. Image paths are relative to this README file.
For example if there is an image subfolder under your extension project workspace:
\!\[feature X\]\(images/feature-x.png\)
> Tip: Many popular extensions utilize animations. This is an excellent way to show off your extension! We recommend short, focused animations that are easy to follow.
## Requirements
If you have any requirements or dependencies, add a section describing those and how to install and configure them.
## Extension Settings
Include if your extension adds any VS Code settings through the `contributes.configuration` extension point.
For example:
This extension contributes the following settings:
* `myExtension.enable`: Enable/disable this extension.
* `myExtension.thing`: Set to `blah` to do something.
## Known Issues
Calling out known issues can help limit users opening duplicate issues against your extension.
## Release Notes
Users appreciate release notes as you update your extension.
### 1.0.0
Initial release of ...
### 1.0.1
Fixed issue #.
### 1.1.0
Added features X, Y, and Z.
---
## Following extension guidelines
Ensure that you've read through the extensions guidelines and follow the best practices for creating your extension.
* [Extension Guidelines](https://code.visualstudio.com/api/references/extension-guidelines)
## Working with Markdown
You can author your README using Visual Studio Code. Here are some useful editor keyboard shortcuts:
* Split the editor (`Cmd+\` on macOS or `Ctrl+\` on Windows and Linux).
* Toggle preview (`Shift+Cmd+V` on macOS or `Shift+Ctrl+V` on Windows and Linux).
* Press `Ctrl+Space` (Windows, Linux, macOS) to see a list of Markdown snippets.
## For more information
* [Visual Studio Code's Markdown Support](http://code.visualstudio.com/docs/languages/markdown)
* [Markdown Syntax Reference](https://help.github.com/articles/markdown-basics/)
**Enjoy!**

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

BIN
vsc/assets/RA.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

56
vsc/esbuild.js Normal file
View File

@ -0,0 +1,56 @@
const esbuild = require("esbuild");
const production = process.argv.includes('--production');
const watch = process.argv.includes('--watch');
/**
* @type {import('esbuild').Plugin}
*/
const esbuildProblemMatcherPlugin = {
name: 'esbuild-problem-matcher',
setup(build) {
build.onStart(() => {
console.log('[watch] build started');
});
build.onEnd((result) => {
result.errors.forEach(({ text, location }) => {
console.error(`✘ [ERROR] ${text}`);
console.error(` ${location.file}:${location.line}:${location.column}:`);
});
console.log('[watch] build finished');
});
},
};
async function main() {
const ctx = await esbuild.context({
entryPoints: [
'src/extension.ts'
],
bundle: true,
format: 'cjs',
minify: production,
sourcemap: !production,
sourcesContent: false,
platform: 'node',
outfile: 'dist/extension.js',
external: ['vscode'],
logLevel: 'silent',
plugins: [
/* add to the end of plugins array */
esbuildProblemMatcherPlugin,
],
});
if (watch) {
await ctx.watch();
} else {
await ctx.rebuild();
await ctx.dispose();
}
}
main().catch(e => {
console.error(e);
process.exit(1);
});

28
vsc/eslint.config.mjs Normal file
View File

@ -0,0 +1,28 @@
import typescriptEslint from "@typescript-eslint/eslint-plugin";
import tsParser from "@typescript-eslint/parser";
export default [{
files: ["**/*.ts"],
}, {
plugins: {
"@typescript-eslint": typescriptEslint,
},
languageOptions: {
parser: tsParser,
ecmaVersion: 2022,
sourceType: "module",
},
rules: {
"@typescript-eslint/naming-convention": ["warn", {
selector: "import",
format: ["camelCase", "PascalCase"],
}],
curly: "warn",
eqeqeq: "warn",
"no-throw-literal": "warn",
semi: "warn",
},
}];

5960
vsc/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

67
vsc/package.json Normal file
View File

@ -0,0 +1,67 @@
{
"name": "ra-aid",
"displayName": "RA.Aid",
"description": "Develop software autonomously.",
"version": "0.0.1",
"engines": {
"vscode": "^1.98.0"
},
"categories": [
"Other"
],
"activationEvents": [],
"main": "./dist/extension.js",
"contributes": {
"viewsContainers": {
"activitybar": [
{
"id": "ra-aid-view",
"title": "RA.Aid",
"icon": "assets/RA-white-transp.png"
}
]
},
"views": {
"ra-aid-view": [
{
"type": "webview",
"id": "ra-aid.view",
"name": "RA.Aid"
}
]
},
"commands": [
{
"command": "ra-aid.helloWorld",
"title": "Hello World"
}
]
},
"scripts": {
"vscode:prepublish": "npm run package",
"compile": "npm run check-types && npm run lint && node esbuild.js",
"watch": "npm-run-all -p watch:*",
"watch:esbuild": "node esbuild.js --watch",
"watch:tsc": "tsc --noEmit --watch --project tsconfig.json",
"package": "npm run check-types && npm run lint && node esbuild.js --production",
"compile-tests": "tsc -p . --outDir out",
"watch-tests": "tsc -p . -w --outDir out",
"pretest": "npm run compile-tests && npm run compile && npm run lint",
"check-types": "tsc --noEmit",
"lint": "eslint src",
"test": "vscode-test"
},
"devDependencies": {
"@types/vscode": "^1.98.0",
"@types/mocha": "^10.0.10",
"@types/node": "20.x",
"@typescript-eslint/eslint-plugin": "^8.25.0",
"@typescript-eslint/parser": "^8.25.0",
"eslint": "^9.21.0",
"esbuild": "^0.25.0",
"npm-run-all": "^4.1.5",
"typescript": "^5.7.3",
"@vscode/test-cli": "^0.0.10",
"@vscode/test-electron": "^2.4.1"
}
}

133
vsc/src/extension.ts Normal file
View File

@ -0,0 +1,133 @@
// The module 'vscode' contains the VS Code extensibility API
import * as vscode from 'vscode';
/**
* WebviewViewProvider implementation for the RA.Aid panel
*/
class RAWebviewViewProvider implements vscode.WebviewViewProvider {
constructor(private readonly _extensionUri: vscode.Uri) {}
/**
* Called when a view is first created to initialize the webview
*/
public resolveWebviewView(
webviewView: vscode.WebviewView,
context: vscode.WebviewViewResolveContext,
_token: vscode.CancellationToken
) {
// Set options for the webview
webviewView.webview.options = {
// Enable JavaScript in the webview
enableScripts: true,
// Restrict the webview to only load resources from the extension's directory
localResourceRoots: [this._extensionUri]
};
// Set the HTML content of the webview
webviewView.webview.html = this._getHtmlForWebview(webviewView.webview);
}
/**
* Creates HTML content for the webview with proper security policies
*/
private _getHtmlForWebview(webview: vscode.Webview): string {
// Create a URI to the extension's assets directory
const logoUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'assets', 'RA.png'));
// Create a URI to the script file
// const scriptUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'dist', 'webview.js'));
// Use a nonce to whitelist scripts
const nonce = getNonce();
return `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="Content-Security-Policy" content="default-src 'none'; img-src ${webview.cspSource} https:; style-src ${webview.cspSource} 'unsafe-inline'; script-src 'nonce-${nonce}';">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>RA.Aid</title>
<style>
body {
padding: 0;
color: var(--vscode-foreground);
font-size: var(--vscode-font-size);
font-weight: var(--vscode-font-weight);
font-family: var(--vscode-font-family);
background-color: var(--vscode-editor-background);
}
.container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 20px;
text-align: center;
}
.logo {
width: 100px;
height: 100px;
margin-bottom: 20px;
}
h1 {
color: var(--vscode-editor-foreground);
font-size: 1.3em;
margin-bottom: 15px;
}
p {
color: var(--vscode-foreground);
margin-bottom: 10px;
}
</style>
</head>
<body>
<div class="container">
<img src="${logoUri}" alt="RA.Aid Logo" class="logo">
<h1>RA.Aid</h1>
<p>Your research and development assistant.</p>
<p>More features coming soon!</p>
</div>
</body>
</html>`;
}
}
/**
* Generates a random nonce for CSP
*/
function getNonce() {
let text = '';
const possible = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';
for (let i = 0; i < 32; i++) {
text += possible.charAt(Math.floor(Math.random() * possible.length));
}
return text;
}
// This method is called when your extension is activated
export function activate(context: vscode.ExtensionContext) {
// Use the console to output diagnostic information (console.log) and errors (console.error)
console.log('Congratulations, your extension "ra-aid" is now active!');
// Register the WebviewViewProvider
const provider = new RAWebviewViewProvider(context.extensionUri);
const viewRegistration = vscode.window.registerWebviewViewProvider(
'ra-aid.view', // Must match the view id in package.json
provider
);
context.subscriptions.push(viewRegistration);
// The command has been defined in the package.json file
// Now provide the implementation of the command with registerCommand
// The commandId parameter must match the command field in package.json
const disposable = vscode.commands.registerCommand('ra-aid.helloWorld', () => {
// The code you place here will be executed every time your command is executed
// Display a message box to the user
vscode.window.showInformationMessage('Hello World from RA.Aid!');
});
context.subscriptions.push(disposable);
}
// This method is called when your extension is deactivated
export function deactivate() {}

View File

@ -0,0 +1,15 @@
import * as assert from 'assert';
// You can import and use all API from the 'vscode' module
// as well as import your extension to test it
import * as vscode from 'vscode';
// import * as myExtension from '../../extension';
suite('Extension Test Suite', () => {
vscode.window.showInformationMessage('Start all tests.');
test('Sample test', () => {
assert.strictEqual(-1, [1, 2, 3].indexOf(5));
assert.strictEqual(-1, [1, 2, 3].indexOf(0));
});
});

16
vsc/tsconfig.json Normal file
View File

@ -0,0 +1,16 @@
{
"compilerOptions": {
"module": "Node16",
"target": "ES2022",
"lib": [
"ES2022"
],
"sourceMap": true,
"rootDir": "src",
"strict": true, /* enable all strict type-checking options */
/* Additional Checks */
// "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */
// "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */
// "noUnusedParameters": true, /* Report errors on unused parameters. */
}
}

View File

@ -0,0 +1,48 @@
# Welcome to your VS Code Extension
## What's in the folder
* This folder contains all of the files necessary for your extension.
* `package.json` - this is the manifest file in which you declare your extension and command.
* The sample plugin registers a command and defines its title and command name. With this information VS Code can show the command in the command palette. It doesnt yet need to load the plugin.
* `src/extension.ts` - this is the main file where you will provide the implementation of your command.
* The file exports one function, `activate`, which is called the very first time your extension is activated (in this case by executing the command). Inside the `activate` function we call `registerCommand`.
* We pass the function containing the implementation of the command as the second parameter to `registerCommand`.
## Setup
* install the recommended extensions (amodio.tsl-problem-matcher, ms-vscode.extension-test-runner, and dbaeumer.vscode-eslint)
## Get up and running straight away
* Press `F5` to open a new window with your extension loaded.
* Run your command from the command palette by pressing (`Ctrl+Shift+P` or `Cmd+Shift+P` on Mac) and typing `Hello World`.
* Set breakpoints in your code inside `src/extension.ts` to debug your extension.
* Find output from your extension in the debug console.
## Make changes
* You can relaunch the extension from the debug toolbar after changing code in `src/extension.ts`.
* You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes.
## Explore the API
* You can open the full set of our API when you open the file `node_modules/@types/vscode/index.d.ts`.
## Run tests
* Install the [Extension Test Runner](https://marketplace.visualstudio.com/items?itemName=ms-vscode.extension-test-runner)
* Run the "watch" task via the **Tasks: Run Task** command. Make sure this is running, or tests might not be discovered.
* Open the Testing view from the activity bar and click the Run Test" button, or use the hotkey `Ctrl/Cmd + ; A`
* See the output of the test result in the Test Results view.
* Make changes to `src/test/extension.test.ts` or create new test files inside the `test` folder.
* The provided test runner will only consider files matching the name pattern `**.test.ts`.
* You can create folders inside the `test` folder to structure your tests any way you want.
## Go further
* Reduce the extension size and improve the startup time by [bundling your extension](https://code.visualstudio.com/api/working-with-extensions/bundling-extension).
* [Publish your extension](https://code.visualstudio.com/api/working-with-extensions/publishing-extension) on the VS Code extension marketplace.
* Automate builds by setting up [Continuous Integration](https://code.visualstudio.com/api/working-with-extensions/continuous-integration).