feat: add session and trajectory models to track application state and events
- Introduce a new `Session` model to store information about each program run, including command line arguments and environment details. - Implement a `Trajectory` model to log significant events and errors during execution, enhancing debugging and monitoring capabilities. - Update various repository classes to support session and trajectory management, allowing for better tracking of user interactions and system behavior. - Modify existing functions to record relevant events in the trajectory, ensuring comprehensive logging of application activities. - Enhance error handling by logging errors to the trajectory, providing insights into failures and system performance. feat(vsc): add initial setup for VS Code extension "ra-aid" with essential files and configurations chore(vsc): create tasks.json for managing build and watch tasks in VS Code chore(vsc): add .vscodeignore to exclude unnecessary files from the extension package docs(vsc): create CHANGELOG.md to document changes and updates for the extension docs(vsc): add README.md with instructions and information about the extension feat(vsc): include esbuild.js for building and bundling the extension chore(vsc): add eslint.config.mjs for TypeScript linting configuration chore(vsc): create package.json with dependencies and scripts for the extension feat(vsc): implement extension logic in src/extension.ts with webview support test(vsc): add initial test suite in extension.test.ts for extension functionality chore(vsc): create tsconfig.json for TypeScript compiler options docs(vsc): add vsc-extension-quickstart.md for guidance on extension development
This commit is contained in:
commit
77a256317a
|
|
@ -14,3 +14,5 @@ __pycache__/
|
|||
.envrc
|
||||
appmap.log
|
||||
*.swp
|
||||
/vsc/node_modules
|
||||
/vsc/dist
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ from ra_aid.database.repositories.trajectory_repository import (
|
|||
TrajectoryRepositoryManager,
|
||||
get_trajectory_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.session_repository import (
|
||||
SessionRepositoryManager, get_session_repository
|
||||
)
|
||||
from ra_aid.database.repositories.related_files_repository import (
|
||||
RelatedFilesRepositoryManager,
|
||||
)
|
||||
|
|
@ -298,6 +301,11 @@ Examples:
|
|||
action="store_true",
|
||||
help="Display model thinking content extracted from think tags when supported by the model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-cost",
|
||||
action="store_true",
|
||||
help="Display cost information as the agent works",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reasoning-assistance",
|
||||
action="store_true",
|
||||
|
|
@ -538,18 +546,18 @@ def main():
|
|||
env_discovery.discover()
|
||||
env_data = env_discovery.format_markdown()
|
||||
|
||||
with (
|
||||
KeyFactRepositoryManager(db) as key_fact_repo,
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo,
|
||||
HumanInputRepositoryManager(db) as human_input_repo,
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo,
|
||||
RelatedFilesRepositoryManager() as related_files_repo,
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo,
|
||||
WorkLogRepositoryManager() as work_log_repo,
|
||||
ConfigRepositoryManager(config) as config_repo,
|
||||
EnvInvManager(env_data) as env_inv,
|
||||
):
|
||||
with SessionRepositoryManager(db) as session_repo, \
|
||||
KeyFactRepositoryManager(db) as key_fact_repo, \
|
||||
KeySnippetRepositoryManager(db) as key_snippet_repo, \
|
||||
HumanInputRepositoryManager(db) as human_input_repo, \
|
||||
ResearchNoteRepositoryManager(db) as research_note_repo, \
|
||||
RelatedFilesRepositoryManager() as related_files_repo, \
|
||||
TrajectoryRepositoryManager(db) as trajectory_repo, \
|
||||
WorkLogRepositoryManager() as work_log_repo, \
|
||||
ConfigRepositoryManager(config) as config_repo, \
|
||||
EnvInvManager(env_data) as env_inv:
|
||||
# This initializes all repositories and makes them available via their respective get methods
|
||||
logger.debug("Initialized SessionRepository")
|
||||
logger.debug("Initialized KeyFactRepository")
|
||||
logger.debug("Initialized KeySnippetRepository")
|
||||
logger.debug("Initialized HumanInputRepository")
|
||||
|
|
@ -560,6 +568,10 @@ def main():
|
|||
logger.debug("Initialized ConfigRepository")
|
||||
logger.debug("Initialized Environment Inventory")
|
||||
|
||||
# Create a new session for this program run
|
||||
logger.debug("Initializing new session")
|
||||
session_repo.create_session()
|
||||
|
||||
# Check dependencies before proceeding
|
||||
check_dependencies()
|
||||
|
||||
|
|
@ -611,6 +623,7 @@ def main():
|
|||
)
|
||||
config_repo.set("web_research_enabled", web_research_enabled)
|
||||
config_repo.set("show_thoughts", args.show_thoughts)
|
||||
config_repo.set("show_cost", args.show_cost)
|
||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||
config_repo.set(
|
||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||
|
|
@ -636,11 +649,41 @@ def main():
|
|||
)
|
||||
|
||||
if args.research_only:
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
error_message = "Chat mode cannot be used with --research-only"
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"display_title": "Error",
|
||||
"error_message": error_message,
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as traj_error:
|
||||
# Swallow exception to avoid recursion
|
||||
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||
pass
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
sys.exit(1)
|
||||
|
||||
print_stage_header("Chat Mode")
|
||||
|
||||
# Record stage transition in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"stage": "chat_mode",
|
||||
"display_title": "Chat Mode",
|
||||
},
|
||||
record_type="stage_transition",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Get project info
|
||||
try:
|
||||
project_info = get_project_info(".", file_limit=2000)
|
||||
|
|
@ -690,12 +733,9 @@ def main():
|
|||
config_repo.set("expert_model", args.expert_model)
|
||||
config_repo.set("temperature", args.temperature)
|
||||
config_repo.set("show_thoughts", args.show_thoughts)
|
||||
config_repo.set(
|
||||
"force_reasoning_assistance", args.reasoning_assistance
|
||||
)
|
||||
config_repo.set(
|
||||
"disable_reasoning_assistance", args.no_reasoning_assistance
|
||||
)
|
||||
config_repo.set("show_cost", args.show_cost)
|
||||
config_repo.set("force_reasoning_assistance", args.reasoning_assistance)
|
||||
config_repo.set("disable_reasoning_assistance", args.no_reasoning_assistance)
|
||||
|
||||
# Set modification tools based on use_aider flag
|
||||
set_modification_tools(args.use_aider)
|
||||
|
|
@ -737,6 +777,24 @@ def main():
|
|||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
error_message = "--message is required"
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"display_title": "Error",
|
||||
"error_message": error_message,
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as traj_error:
|
||||
# Swallow exception to avoid recursion
|
||||
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||
pass
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
|
|
@ -806,6 +864,18 @@ def main():
|
|||
# Run research stage
|
||||
print_stage_header("Research Stage")
|
||||
|
||||
# Record stage transition in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"stage": "research_stage",
|
||||
"display_title": "Research Stage",
|
||||
},
|
||||
record_type="stage_transition",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Initialize research model with potential overrides
|
||||
research_provider = args.research_provider or args.provider
|
||||
research_model_name = args.research_model or args.model
|
||||
|
|
|
|||
|
|
@ -462,6 +462,38 @@ class CiaynAgent:
|
|||
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
||||
tool_name = self.extract_tool_name(code)
|
||||
logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": f"Tool execution failed for `{tool_name}`:\nError: {str(e)}",
|
||||
"display_title": "Tool Error",
|
||||
"code": code,
|
||||
"tool_name": tool_name
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="ToolExecutionError",
|
||||
tool_name=tool_name,
|
||||
tool_parameters={"code": code}
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for tool error display: {trajectory_error}")
|
||||
|
||||
print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
|
||||
raise ToolExecutionError(
|
||||
error_msg, base_message=msg, tool_name=tool_name
|
||||
|
|
@ -495,6 +527,36 @@ class CiaynAgent:
|
|||
if not fallback_response:
|
||||
self.chat_history.append(err_msg)
|
||||
logger.info(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
|
||||
"display_title": "Fallback Failed",
|
||||
"tool_name": e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="FallbackFailedError",
|
||||
tool_name=e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for fallback failed warning: {trajectory_error}")
|
||||
|
||||
print_warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", title="Fallback Failed")
|
||||
return ""
|
||||
|
||||
|
|
@ -595,6 +657,35 @@ class CiaynAgent:
|
|||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
logger.info("Failed to extract a valid tool call from the model's response.")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": "Failed to extract a valid tool call from the model's response.",
|
||||
"display_title": "Extraction Failed",
|
||||
"code": code
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message="Failed to extract a valid tool call from the model's response.",
|
||||
error_type="ExtractionError"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for extraction error display: {trajectory_error}")
|
||||
|
||||
print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
|
||||
raise ToolExecutionError("Failed to extract tool call")
|
||||
ma = matches[0][0].strip()
|
||||
|
|
@ -647,6 +738,36 @@ class CiaynAgent:
|
|||
|
||||
warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
|
||||
logger.info(warning_message)
|
||||
|
||||
# Record warning in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db_connection
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||
human_input_repo = HumanInputRepository(get_db_connection())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"warning_message": warning_message,
|
||||
"display_title": "Empty Response",
|
||||
"attempt": empty_response_count,
|
||||
"max_attempts": max_empty_responses
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=warning_message,
|
||||
error_type="EmptyResponseWarning"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for empty response warning: {trajectory_error}")
|
||||
|
||||
print_warning(warning_message, title="Empty Response")
|
||||
|
||||
if empty_response_count >= max_empty_responses:
|
||||
|
|
@ -658,6 +779,36 @@ class CiaynAgent:
|
|||
|
||||
error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
|
||||
logger.error(error_message)
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db_connection
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||
human_input_repo = HumanInputRepository(get_db_connection())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Agent Crashed",
|
||||
"crash_reason": crash_message,
|
||||
"attempts": empty_response_count
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
error_type="AgentCrashError"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for agent crash: {trajectory_error}")
|
||||
|
||||
print_error(error_message)
|
||||
|
||||
yield self._create_error_chunk(crash_message)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ from ra_aid.fallback_handler import FallbackHandler
|
|||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command
|
||||
from ra_aid.database.repositories.human_input_repository import (
|
||||
get_human_input_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.anthropic_token_limiter import sonnet_35_state_modifier, state_modifier, get_model_token_limit
|
||||
|
||||
|
|
@ -284,9 +288,23 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
|
||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from ra_aid import agent_utils
|
|||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -82,6 +83,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
if deleted_facts:
|
||||
deleted_msg = "Successfully deleted facts:\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in deleted_facts])
|
||||
result_parts.append(deleted_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_facts": deleted_facts,
|
||||
"display_title": "Facts Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
|
||||
)
|
||||
|
|
@ -89,6 +106,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
if protected_facts:
|
||||
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_facts": protected_facts,
|
||||
"display_title": "Facts Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -120,10 +153,44 @@ def run_key_facts_gc_agent() -> None:
|
|||
fact_count = len(facts)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
# Record GC error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error": str(e),
|
||||
"display_title": "GC Error",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="Repository Error"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||
return # Exit the function if we can't access the repository
|
||||
|
||||
# Display status panel with fact count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"fact_count": fact_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have facts to clean
|
||||
|
|
@ -185,6 +252,24 @@ def run_key_facts_gc_agent() -> None:
|
|||
# Show info panel with updated count and protected facts count
|
||||
protected_count = len(protected_facts)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": fact_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}\nProtected facts (associated with current request): {protected_count}",
|
||||
|
|
@ -192,6 +277,24 @@ def run_key_facts_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": fact_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}",
|
||||
|
|
@ -199,6 +302,40 @@ def run_key_facts_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_facts),
|
||||
"message": "All facts are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"fact_count": 0,
|
||||
"message": "No key facts to clean",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -18,6 +18,7 @@ from ra_aid import agent_utils
|
|||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -65,6 +66,23 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
success = get_key_snippet_repository().delete(snippet_id)
|
||||
if success:
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_snippet_id": snippet_id,
|
||||
"filepath": filepath,
|
||||
"display_title": "Snippet Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||
|
|
@ -86,6 +104,22 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
if protected_snippets:
|
||||
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_snippets": protected_snippets,
|
||||
"display_title": "Snippets Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -116,6 +150,21 @@ def run_key_snippets_gc_agent() -> None:
|
|||
snippet_count = len(snippets)
|
||||
|
||||
# Display status panel with snippet count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"snippet_count": snippet_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key snippets: {snippet_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have snippets to clean
|
||||
|
|
@ -185,6 +234,24 @@ def run_key_snippets_gc_agent() -> None:
|
|||
# Show info panel with updated count and protected snippets count
|
||||
protected_count = len(protected_snippets)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": snippet_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}\nProtected snippets (associated with current request): {protected_count}",
|
||||
|
|
@ -192,6 +259,24 @@ def run_key_snippets_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": snippet_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
||||
|
|
@ -199,6 +284,40 @@ def run_key_snippets_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_snippets),
|
||||
"message": "All snippets are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"snippet_count": 0,
|
||||
"message": "No key snippets to clean",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -24,6 +24,8 @@ from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_
|
|||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.env_inv_context import get_env_inv
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
from ra_aid.llm import initialize_expert_llm
|
||||
|
|
@ -156,6 +158,18 @@ def run_planning_agent(
|
|||
# Display the planning stage header before any reasoning assistance
|
||||
print_stage_header("Planning Stage")
|
||||
|
||||
# Record stage transition in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"stage": "planning_stage",
|
||||
"display_title": "Planning Stage",
|
||||
},
|
||||
record_type="stage_transition",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Initialize expert guidance section
|
||||
expert_guidance = ""
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
|||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -84,6 +85,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
|||
if deleted_notes:
|
||||
deleted_msg = "Successfully deleted research notes:\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in deleted_notes])
|
||||
result_parts.append(deleted_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_notes": deleted_notes,
|
||||
"display_title": "Research Notes Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(deleted_msg), title="Research Notes Deleted", border_style="green")
|
||||
)
|
||||
|
|
@ -91,6 +108,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
|||
if protected_notes:
|
||||
protected_msg = "Protected research notes (associated with current request):\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in protected_notes])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_notes": protected_notes,
|
||||
"display_title": "Research Notes Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Research Notes Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -125,10 +158,44 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
|
|||
note_count = len(notes)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access research note repository: {str(e)}")
|
||||
# Record GC error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error": str(e),
|
||||
"display_title": "GC Error",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="Repository Error"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||
return # Exit the function if we can't access the repository
|
||||
|
||||
# Display status panel with note count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"note_count": note_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of research notes: {note_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have notes to clean and we're over the threshold
|
||||
|
|
@ -235,6 +302,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
# Show info panel with updated count and protected notes count
|
||||
protected_count = len(protected_notes)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": note_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned research notes: {note_count} → {updated_count}\nProtected notes (associated with current request): {protected_count}",
|
||||
|
|
@ -242,6 +327,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": note_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned research notes: {note_count} → {updated_count}",
|
||||
|
|
@ -249,6 +352,41 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_notes),
|
||||
"message": "All research notes are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_notes)} research notes are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"note_count": note_count,
|
||||
"threshold": threshold,
|
||||
"message": "Below threshold - no cleanup needed",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Research notes count ({note_count}) is below threshold ({threshold}). No cleanup needed.", title="🗑 GC Info"))
|
||||
|
|
@ -7,6 +7,7 @@ FALLBACK_TOOL_MODEL_LIMIT = 5
|
|||
RETRY_FALLBACK_COUNT = 3
|
||||
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
|
||||
DEFAULT_MODEL="claude-3-7-sonnet-20250219"
|
||||
DEFAULT_SHOW_COST = False
|
||||
|
||||
|
||||
VALID_PROVIDERS = [
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from typing import Optional
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,18 @@ from rich.panel import Panel
|
|||
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.callbacks.anthropic_callback_handler import AnthropicCallbackHandler
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.config import DEFAULT_SHOW_COST
|
||||
|
||||
# Import shared console instance
|
||||
from .formatting import console
|
||||
|
||||
|
||||
def get_cost_subtitle(cost_cb: Optional[AnthropicCallbackHandler]) -> Optional[str]:
|
||||
"""Generate a subtitle with cost information if a callback is provided."""
|
||||
if cost_cb:
|
||||
"""Generate a subtitle with cost information if a callback is provided and show_cost is enabled."""
|
||||
# Only show cost information if both cost_cb is provided AND show_cost is True
|
||||
show_cost = get_config_repository().get("show_cost", DEFAULT_SHOW_COST)
|
||||
if cost_cb and show_cost:
|
||||
return f"Cost: ${cost_cb.total_cost:.6f} | Tokens: {cost_cb.total_tokens}"
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ def initialize_database():
|
|||
# to avoid circular imports
|
||||
# Note: This import needs to be here, not at the top level
|
||||
try:
|
||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory
|
||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory], safe=True)
|
||||
from ra_aid.database.models import KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session
|
||||
db.create_tables([KeyFact, KeySnippet, HumanInput, ResearchNote, Trajectory, Session], safe=True)
|
||||
logger.debug("Ensured database tables exist")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tables: {str(e)}")
|
||||
|
|
@ -99,6 +99,25 @@ class BaseModel(peewee.Model):
|
|||
raise
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
"""
|
||||
Model representing a session stored in the database.
|
||||
|
||||
Sessions track information about each program run, providing a way to group
|
||||
related records like human inputs, trajectories, and key facts.
|
||||
|
||||
Each session record captures details about when the program was started,
|
||||
what command line arguments were used, and environment information.
|
||||
"""
|
||||
start_time = peewee.DateTimeField(default=datetime.datetime.now)
|
||||
command_line = peewee.TextField(null=True)
|
||||
program_version = peewee.TextField(null=True)
|
||||
machine_info = peewee.TextField(null=True) # JSON-encoded machine information
|
||||
|
||||
class Meta:
|
||||
table_name = "session"
|
||||
|
||||
|
||||
class HumanInput(BaseModel):
|
||||
"""
|
||||
Model representing human input stored in the database.
|
||||
|
|
@ -109,6 +128,7 @@ class HumanInput(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
source = peewee.TextField() # 'cli', 'chat', or 'hil'
|
||||
session = peewee.ForeignKeyField(Session, backref='human_inputs', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -124,6 +144,7 @@ class KeyFact(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_facts', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='key_facts', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -143,6 +164,7 @@ class KeySnippet(BaseModel):
|
|||
snippet = peewee.TextField()
|
||||
description = peewee.TextField(null=True)
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='key_snippets', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='key_snippets', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -159,6 +181,7 @@ class ResearchNote(BaseModel):
|
|||
"""
|
||||
content = peewee.TextField()
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='research_notes', null=True)
|
||||
session = peewee.ForeignKeyField(Session, backref='research_notes', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
@ -182,17 +205,18 @@ class Trajectory(BaseModel):
|
|||
- Error information (when a tool execution fails)
|
||||
"""
|
||||
human_input = peewee.ForeignKeyField(HumanInput, backref='trajectories', null=True)
|
||||
tool_name = peewee.TextField()
|
||||
tool_parameters = peewee.TextField() # JSON-encoded parameters
|
||||
tool_result = peewee.TextField() # JSON-encoded result
|
||||
step_data = peewee.TextField() # JSON-encoded UI rendering data
|
||||
record_type = peewee.TextField() # Type of trajectory record
|
||||
tool_name = peewee.TextField(null=True)
|
||||
tool_parameters = peewee.TextField(null=True) # JSON-encoded parameters
|
||||
tool_result = peewee.TextField(null=True) # JSON-encoded result
|
||||
step_data = peewee.TextField(null=True) # JSON-encoded UI rendering data
|
||||
record_type = peewee.TextField(null=True) # Type of trajectory record
|
||||
cost = peewee.FloatField(null=True) # Placeholder for cost tracking
|
||||
tokens = peewee.IntegerField(null=True) # Placeholder for token usage tracking
|
||||
is_error = peewee.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||
error_message = peewee.TextField(null=True) # The error message
|
||||
error_type = peewee.TextField(null=True) # The type/class of the error
|
||||
error_details = peewee.TextField(null=True) # Additional error details like stack traces or context
|
||||
session = peewee.ForeignKeyField(Session, backref='trajectories', null=True)
|
||||
# created_at and updated_at are inherited from BaseModel
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class ConfigRepository:
|
|||
FALLBACK_TOOL_MODEL_LIMIT,
|
||||
RETRY_FALLBACK_COUNT,
|
||||
DEFAULT_TEST_CMD_TIMEOUT,
|
||||
DEFAULT_SHOW_COST,
|
||||
VALID_PROVIDERS,
|
||||
)
|
||||
|
||||
|
|
@ -42,6 +43,7 @@ class ConfigRepository:
|
|||
"fallback_tool_model_limit": FALLBACK_TOOL_MODEL_LIMIT,
|
||||
"retry_fallback_count": RETRY_FALLBACK_COUNT,
|
||||
"test_cmd_timeout": DEFAULT_TEST_CMD_TIMEOUT,
|
||||
"show_cost": DEFAULT_SHOW_COST,
|
||||
"valid_providers": VALID_PROVIDERS,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,243 @@
|
|||
"""
|
||||
Session repository implementation for database access.
|
||||
|
||||
This module provides a repository implementation for the Session model,
|
||||
following the repository pattern for data access abstraction. It handles
|
||||
operations for storing and retrieving application session information.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import peewee
|
||||
|
||||
from ra_aid.database.models import Session
|
||||
from ra_aid.__version__ import __version__
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Create contextvar to hold the SessionRepository instance
|
||||
session_repo_var = contextvars.ContextVar("session_repo", default=None)
|
||||
|
||||
|
||||
class SessionRepositoryManager:
|
||||
"""
|
||||
Context manager for SessionRepository.
|
||||
|
||||
This class provides a context manager interface for SessionRepository,
|
||||
using the contextvars approach for thread safety.
|
||||
|
||||
Example:
|
||||
with DatabaseManager() as db:
|
||||
with SessionRepositoryManager(db) as repo:
|
||||
# Use the repository
|
||||
session = repo.create_session()
|
||||
current_session = repo.get_current_session()
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the SessionRepositoryManager.
|
||||
|
||||
Args:
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def __enter__(self) -> 'SessionRepository':
|
||||
"""
|
||||
Initialize the SessionRepository and return it.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The initialized repository
|
||||
"""
|
||||
repo = SessionRepository(self.db)
|
||||
session_repo_var.set(repo)
|
||||
return repo
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[Exception],
|
||||
exc_tb: Optional[object],
|
||||
) -> None:
|
||||
"""
|
||||
Reset the repository when exiting the context.
|
||||
|
||||
Args:
|
||||
exc_type: The exception type if an exception was raised
|
||||
exc_val: The exception value if an exception was raised
|
||||
exc_tb: The traceback if an exception was raised
|
||||
"""
|
||||
# Reset the contextvar to None
|
||||
session_repo_var.set(None)
|
||||
|
||||
# Don't suppress exceptions
|
||||
return False
|
||||
|
||||
|
||||
def get_session_repository() -> 'SessionRepository':
|
||||
"""
|
||||
Get the current SessionRepository instance.
|
||||
|
||||
Returns:
|
||||
SessionRepository: The current repository instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no repository has been initialized with SessionRepositoryManager
|
||||
"""
|
||||
repo = session_repo_var.get()
|
||||
if repo is None:
|
||||
raise RuntimeError(
|
||||
"No SessionRepository available. "
|
||||
"Make sure to initialize one with SessionRepositoryManager first."
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
class SessionRepository:
|
||||
"""
|
||||
Repository for handling Session records in the database.
|
||||
|
||||
This class provides methods for creating, retrieving, and managing Session records.
|
||||
It abstracts away the database operations and provides a clean interface for working
|
||||
with Session entities.
|
||||
"""
|
||||
|
||||
def __init__(self, db):
|
||||
"""
|
||||
Initialize the SessionRepository.
|
||||
|
||||
Args:
|
||||
db: Database connection to use (required)
|
||||
"""
|
||||
if db is None:
|
||||
raise ValueError("Database connection is required for SessionRepository")
|
||||
self.db = db
|
||||
self.current_session = None
|
||||
|
||||
def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> Session:
|
||||
"""
|
||||
Create a new session record in the database.
|
||||
|
||||
Args:
|
||||
metadata: Optional dictionary of additional metadata to store with the session
|
||||
|
||||
Returns:
|
||||
Session: The newly created session instance
|
||||
|
||||
Raises:
|
||||
peewee.DatabaseError: If there's an error creating the record
|
||||
"""
|
||||
try:
|
||||
# Get command line arguments
|
||||
command_line = " ".join(sys.argv)
|
||||
|
||||
# Get program version
|
||||
program_version = __version__
|
||||
|
||||
# JSON encode metadata if provided
|
||||
machine_info = json.dumps(metadata) if metadata is not None else None
|
||||
|
||||
session = Session.create(
|
||||
start_time=datetime.datetime.now(),
|
||||
command_line=command_line,
|
||||
program_version=program_version,
|
||||
machine_info=machine_info
|
||||
)
|
||||
|
||||
# Store the current session
|
||||
self.current_session = session
|
||||
|
||||
logger.debug(f"Created new session with ID {session.id}")
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create session record: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_session(self) -> Optional[Session]:
|
||||
"""
|
||||
Get the current active session.
|
||||
|
||||
If no session has been created in this repository instance,
|
||||
retrieves the most recent session from the database.
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The current session or None if no sessions exist
|
||||
"""
|
||||
if self.current_session is not None:
|
||||
return self.current_session
|
||||
|
||||
try:
|
||||
# Find the most recent session
|
||||
session = Session.select().order_by(Session.created_at.desc()).first()
|
||||
if session:
|
||||
self.current_session = session
|
||||
return session
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get current session: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_current_session_id(self) -> Optional[int]:
|
||||
"""
|
||||
Get the ID of the current active session.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The ID of the current session or None if no session exists
|
||||
"""
|
||||
session = self.get_current_session()
|
||||
return session.id if session else None
|
||||
|
||||
def get(self, session_id: int) -> Optional[Session]:
|
||||
"""
|
||||
Get a session by its ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to retrieve
|
||||
|
||||
Returns:
|
||||
Optional[Session]: The session with the given ID or None if not found
|
||||
"""
|
||||
try:
|
||||
return Session.get_or_none(Session.id == session_id)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Database error getting session {session_id}: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_all(self) -> List[Session]:
|
||||
"""
|
||||
Get all sessions from the database.
|
||||
|
||||
Returns:
|
||||
List[Session]: List of all sessions
|
||||
"""
|
||||
try:
|
||||
return list(Session.select().order_by(Session.created_at.desc()))
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get all sessions: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_recent(self, limit: int = 10) -> List[Session]:
|
||||
"""
|
||||
Get the most recent sessions from the database.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of sessions to return (default: 10)
|
||||
|
||||
Returns:
|
||||
List[Session]: List of the most recent sessions
|
||||
"""
|
||||
try:
|
||||
return list(
|
||||
Session.select()
|
||||
.order_by(Session.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to get recent sessions: {str(e)}")
|
||||
return []
|
||||
|
|
@ -132,8 +132,8 @@ class TrajectoryRepository:
|
|||
|
||||
def create(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_parameters: Dict[str, Any],
|
||||
tool_name: Optional[str] = None,
|
||||
tool_parameters: Optional[Dict[str, Any]] = None,
|
||||
tool_result: Optional[Dict[str, Any]] = None,
|
||||
step_data: Optional[Dict[str, Any]] = None,
|
||||
record_type: str = "tool_execution",
|
||||
|
|
@ -149,8 +149,8 @@ class TrajectoryRepository:
|
|||
Create a new trajectory record in the database.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool that was executed
|
||||
tool_parameters: Parameters passed to the tool (will be JSON encoded)
|
||||
tool_name: Optional name of the tool that was executed
|
||||
tool_parameters: Optional parameters passed to the tool (will be JSON encoded)
|
||||
tool_result: Result returned by the tool (will be JSON encoded)
|
||||
step_data: UI rendering data (will be JSON encoded)
|
||||
record_type: Type of trajectory record
|
||||
|
|
@ -170,7 +170,7 @@ class TrajectoryRepository:
|
|||
"""
|
||||
try:
|
||||
# Serialize JSON fields
|
||||
tool_parameters_json = json.dumps(tool_parameters)
|
||||
tool_parameters_json = json.dumps(tool_parameters) if tool_parameters is not None else None
|
||||
tool_result_json = json.dumps(tool_result) if tool_result is not None else None
|
||||
step_data_json = json.dumps(step_data) if step_data is not None else None
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class TrajectoryRepository:
|
|||
# Create the trajectory record
|
||||
trajectory = Trajectory.create(
|
||||
human_input=human_input,
|
||||
tool_name=tool_name,
|
||||
tool_name=tool_name or "", # Use empty string if tool_name is None
|
||||
tool_parameters=tool_parameters_json,
|
||||
tool_result=tool_result_json,
|
||||
step_data=step_data_json,
|
||||
|
|
@ -197,7 +197,10 @@ class TrajectoryRepository:
|
|||
error_type=error_type,
|
||||
error_details=error_details
|
||||
)
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
if tool_name:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} for tool: {tool_name}")
|
||||
else:
|
||||
logger.debug(f"Created trajectory record ID {trajectory.id} of type: {record_type}")
|
||||
return trajectory
|
||||
except peewee.DatabaseError as e:
|
||||
logger.error(f"Failed to create trajectory record: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -154,6 +154,24 @@ class FallbackHandler:
|
|||
logger.debug(
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||
)
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
"display_title": "Fallback Notification",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm(
|
||||
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
title="Fallback Notification",
|
||||
|
|
@ -163,6 +181,24 @@ class FallbackHandler:
|
|||
if result_list:
|
||||
return result_list
|
||||
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "All fallback models have failed.",
|
||||
"display_title": "Fallback Failed",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm("All fallback models have failed.", title="Fallback Failed")
|
||||
|
||||
current_failing_tool_name = self.current_failing_tool_name
|
||||
|
|
|
|||
|
|
@ -234,6 +234,24 @@ def create_llm_client(
|
|||
elif supports_temperature:
|
||||
if temperature is None:
|
||||
temperature = 0.7
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
|
||||
"display_title": "Information",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm(
|
||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,11 +51,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
tool_name = pw.TextField()
|
||||
tool_parameters = pw.TextField() # JSON-encoded parameters
|
||||
tool_result = pw.TextField() # JSON-encoded result
|
||||
step_data = pw.TextField() # JSON-encoded UI rendering data
|
||||
record_type = pw.TextField() # Type of trajectory record
|
||||
tool_name = pw.TextField(null=True) # JSON-encoded parameters
|
||||
tool_parameters = pw.TextField(null=True) # JSON-encoded parameters
|
||||
tool_result = pw.TextField(null=True) # JSON-encoded result
|
||||
step_data = pw.TextField(null=True) # JSON-encoded UI rendering data
|
||||
record_type = pw.TextField(null=True) # Type of trajectory record
|
||||
cost = pw.FloatField(null=True) # Placeholder for cost tracking
|
||||
tokens = pw.IntegerField(null=True) # Placeholder for token usage tracking
|
||||
is_error = pw.BooleanField(default=False) # Flag indicating if this record represents an error
|
||||
|
|
|
|||
|
|
@ -0,0 +1,79 @@
|
|||
"""Peewee migrations -- 008_20250311_191232_add_session_model.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Create the session table for storing application session information."""
|
||||
|
||||
table_exists = False
|
||||
# Check if the table already exists
|
||||
try:
|
||||
database.execute_sql("SELECT id FROM session LIMIT 1")
|
||||
# If we reach here, the table exists
|
||||
table_exists = True
|
||||
except pw.OperationalError:
|
||||
# Table doesn't exist, safe to create
|
||||
pass
|
||||
|
||||
# Create the Session model - this registers it in migrator.orm as 'Session'
|
||||
@migrator.create_model
|
||||
class Session(pw.Model):
|
||||
id = pw.AutoField()
|
||||
created_at = pw.DateTimeField()
|
||||
updated_at = pw.DateTimeField()
|
||||
start_time = pw.DateTimeField()
|
||||
command_line = pw.TextField(null=True)
|
||||
program_version = pw.TextField(null=True)
|
||||
machine_info = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "session"
|
||||
|
||||
# FIX: Explicitly register the model under the lowercase table name key
|
||||
# This ensures that later migrations can access it via either:
|
||||
# - migrator.orm['Session'] (class name)
|
||||
# - migrator.orm['session'] (table name)
|
||||
if 'Session' in migrator.orm:
|
||||
migrator.orm['session'] = migrator.orm['Session']
|
||||
|
||||
# Only return after model registration is complete
|
||||
if table_exists:
|
||||
return
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove the session table."""
|
||||
|
||||
migrator.remove_model('session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 009_20250311_191517_add_session_fk_to_human_input.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to HumanInput table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM human_input LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'human_input',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from HumanInput table."""
|
||||
|
||||
migrator.remove_fields('human_input', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 010_20250311_191617_add_session_fk_to_key_fact.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to KeyFact table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM key_fact LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'key_fact',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from KeyFact table."""
|
||||
|
||||
migrator.remove_fields('key_fact', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 011_20250311_191732_add_session_fk_to_key_snippet.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to KeySnippet table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM key_snippet LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'key_snippet',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from KeySnippet table."""
|
||||
|
||||
migrator.remove_fields('key_snippet', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 012_20250311_191832_add_session_fk_to_research_note.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to ResearchNote table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM research_note LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'research_note',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from ResearchNote table."""
|
||||
|
||||
migrator.remove_fields('research_note', 'session')
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
"""Peewee migrations -- 013_20250311_191701_add_session_fk_to_trajectory.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add session foreign key to Trajectory table."""
|
||||
|
||||
# Get the Session model from migrator.orm
|
||||
Session = migrator.orm['session']
|
||||
|
||||
# Check if the column already exists
|
||||
try:
|
||||
database.execute_sql("SELECT session_id FROM trajectory LIMIT 1")
|
||||
# If we reach here, the column exists
|
||||
return
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
pass
|
||||
|
||||
# Add the session_id foreign key column
|
||||
migrator.add_fields(
|
||||
'trajectory',
|
||||
session=pw.ForeignKeyField(
|
||||
Session,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='CASCADE'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove session foreign key from Trajectory table."""
|
||||
|
||||
migrator.remove_fields('trajectory', 'session')
|
||||
|
|
@ -17,6 +17,8 @@ __all__ = [
|
|||
|
||||
from ra_aid.file_listing import FileListerError, get_file_listing
|
||||
from ra_aid.project_state import ProjectStateError, is_new_project
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -130,6 +132,24 @@ def display_project_status(info: ProjectInfo) -> None:
|
|||
{status} with **{file_count} file(s)**
|
||||
"""
|
||||
|
||||
# Record project status in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"project_status": "new" if info.is_new else "existing",
|
||||
"file_count": file_count,
|
||||
"total_files": info.total_files,
|
||||
"display_title": "Project Status",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Silently continue if trajectory recording fails
|
||||
pass
|
||||
|
||||
# Create and display panel
|
||||
console = Console()
|
||||
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))
|
||||
|
|
@ -194,4 +194,5 @@ THE AGENT IS VERY FORGETFUL AND YOUR WRITING MUST INCLUDE REMARKS ABOUT HOW IT S
|
|||
REMEMBER WE ARE INSTRUCTING THE AGENT **HOW TO DO RESEARCH ABOUT WHAT ALREADY EXISTS** AT THIS POINT USING THE TOOLS AVAILABLE. YOU ARE NOT TO DO THE ACTUAL RESEARCH YOURSELF. IF AN IMPLEMENTATION IS REQUESTED, THE AGENT SHOULD BE INSTRUCTED TO CALL request_task_implementation BUT ONLY AFTER EMITTING RESEARCH NOTES, KEY FACTS, AND KEY SNIPPETS AS RELEVANT.
|
||||
IT IS IMPERATIVE THAT WE DO NOT START DIRECTLY IMPLEMENTING ANYTHING AT THIS POINT. WE ARE RESEARCHING, THEN CALLING request_implementation *AT MOST ONCE*.
|
||||
IT IS IMPERATIVE THE AGENT EMITS KEY FACTS AND THOROUGH RESEARCH NOTES AT THIS POINT. THE RESEARCH NOTES CAN JUST BE THOUGHTS AT THIS POINT IF IT IS A NEW PROJECT.
|
||||
THE AGENT MUST ALWAYS CALL emit_research_notes AT LEAST ONCE, ESPECIALLY IF IT CALLS ask_expert.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,11 +15,12 @@ from ra_aid.agent_context import (
|
|||
reset_completion_flags,
|
||||
)
|
||||
from ra_aid.config import DEFAULT_MODEL
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.console.formatting import print_error, print_task_header
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.related_files_repository import get_related_files_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.exceptions import AgentInterrupt
|
||||
|
|
@ -27,8 +28,7 @@ from ra_aid.model_formatters import format_key_facts_dict
|
|||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_notes_dict
|
||||
|
||||
from ..console import print_task_header
|
||||
from ..llm import initialize_llm
|
||||
from ra_aid.llm import initialize_llm
|
||||
from .human import ask_human
|
||||
from .memory import get_related_files, get_work_log
|
||||
|
||||
|
|
@ -63,7 +63,23 @@ def request_research(query: str) -> ResearchResult:
|
|||
# Check recursion depth
|
||||
current_depth = get_depth()
|
||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||
print_error("Maximum research recursion depth reached")
|
||||
error_message = "Maximum research recursion depth reached"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
|
|
@ -110,7 +126,23 @@ def request_research(query: str) -> ResearchResult:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during research: {str(e)}")
|
||||
error_message = f"Error during research: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
|
|
@ -195,7 +227,23 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during web research: {str(e)}")
|
||||
error_message = f"Error during web research: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
|
|
@ -347,6 +395,19 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
|
||||
try:
|
||||
print_task_header(task_spec)
|
||||
|
||||
# Record task display in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"task": task_spec,
|
||||
"display_title": "Task",
|
||||
},
|
||||
record_type="task_display",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Run implementation agent
|
||||
from ..agents.implementation_agent import run_task_implementation_agent
|
||||
|
||||
|
|
@ -372,7 +433,23 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during task implementation: {str(e)}")
|
||||
error_message = f"Error during task implementation: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
|
||||
|
|
@ -503,7 +580,23 @@ def request_implementation(task_spec: str) -> str:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during planning: {str(e)}")
|
||||
error_message = f"Error during planning: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ from rich.panel import Panel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ..database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||
|
|
@ -72,6 +75,23 @@ def emit_expert_context(context: str) -> str:
|
|||
"""
|
||||
expert_context["text"].append(context)
|
||||
|
||||
# Record expert context in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_expert_context",
|
||||
tool_parameters={"context_length": len(context)},
|
||||
step_data={
|
||||
"display_title": "Expert Context",
|
||||
"context_length": len(context),
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Create and display status panel
|
||||
panel_content = f"Added expert context ({len(context)} characters)"
|
||||
console.print(Panel(panel_content, title="Expert Context", border_style="blue"))
|
||||
|
|
@ -184,6 +204,23 @@ def ask_expert(question: str) -> str:
|
|||
# Build display query (just question)
|
||||
display_query = "# Question\n" + question
|
||||
|
||||
# Record expert query in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ask_expert",
|
||||
tool_parameters={"question": question},
|
||||
step_data={
|
||||
"display_title": "Expert Query",
|
||||
"question": question,
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Show only question in panel
|
||||
console.print(
|
||||
Panel(Markdown(display_query), title="🤔 Expert Query", border_style="yellow")
|
||||
|
|
@ -263,6 +300,23 @@ def ask_expert(question: str) -> str:
|
|||
logger.error(f"Exception during content processing: {str(e)}")
|
||||
raise
|
||||
|
||||
# Record expert response in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ask_expert",
|
||||
tool_parameters={"question": question},
|
||||
step_data={
|
||||
"display_title": "Expert Response",
|
||||
"response_length": len(content),
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Format and display response
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Expert Response", border_style="blue")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from rich.panel import Panel
|
|||
from ra_aid.console import console
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.tools.memory import emit_related_files
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
|
||||
def truncate_display_str(s: str, max_length: int = 30) -> str:
|
||||
|
|
@ -54,6 +56,32 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
msg = f"File not found: {filepath}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
|
||||
|
|
@ -62,10 +90,62 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
|
||||
if count == 0:
|
||||
msg = f"String not found: {truncate_display_str(old_str)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
elif count > 1 and not replace_all:
|
||||
msg = f"String appears {count} times - must be unique (use replace_all=True to replace all occurrences)"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
|
||||
|
|
@ -93,7 +173,34 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
emit_related_files.invoke({"files": [filepath]})
|
||||
except Exception as e:
|
||||
# Don't let related files error affect main function success
|
||||
print_error(f"Note: Could not add to related files: {str(e)}")
|
||||
error_msg = f"Note: Could not add to related files: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -102,5 +209,31 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
|
||||
except Exception as e:
|
||||
msg = f"Error: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import fnmatch
|
||||
from typing import List, Tuple
|
||||
import logging
|
||||
from typing import List, Tuple, Dict, Optional, Any
|
||||
|
||||
from fuzzywuzzy import process
|
||||
from git import Repo, exc
|
||||
|
|
@ -12,6 +13,49 @@ from ra_aid.file_listing import get_all_project_files, FileListerError
|
|||
|
||||
console = Console()
|
||||
|
||||
|
||||
def record_trajectory(
|
||||
tool_name: str,
|
||||
tool_parameters: Dict,
|
||||
step_data: Dict,
|
||||
record_type: str = "tool_execution",
|
||||
is_error: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_parameters: Parameters passed to the tool
|
||||
step_data: UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message
|
||||
error_type: The type/class of the error
|
||||
"""
|
||||
try:
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name=tool_name,
|
||||
tool_parameters=tool_parameters,
|
||||
step_data=step_data,
|
||||
record_type=record_type,
|
||||
human_input_id=human_input_id,
|
||||
is_error=is_error,
|
||||
error_message=error_message,
|
||||
error_type=error_type
|
||||
)
|
||||
except (ImportError, RuntimeError):
|
||||
# If either the repository modules can't be imported or no repository is available,
|
||||
# just log and continue without recording trajectory
|
||||
logging.debug("Skipping trajectory recording: repositories not available")
|
||||
|
||||
DEFAULT_EXCLUDE_PATTERNS = [
|
||||
"*.pyc",
|
||||
"__pycache__/*",
|
||||
|
|
@ -57,7 +101,32 @@ def fuzzy_find_project_files(
|
|||
"""
|
||||
# Validate threshold
|
||||
if not 0 <= threshold <= 100:
|
||||
raise ValueError("Threshold must be between 0 and 100")
|
||||
error_msg = "Threshold must be between 0 and 100"
|
||||
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Invalid Threshold Value",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type="ValueError"
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Handle empty search term as special case
|
||||
if not search_term:
|
||||
|
|
@ -126,6 +195,27 @@ def fuzzy_find_project_files(
|
|||
else:
|
||||
info_sections.append("## Results\n*No matches found*")
|
||||
|
||||
# Record fuzzy find in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Fuzzy Find Results",
|
||||
"total_files": len(all_files),
|
||||
"matches_found": len(filtered_matches)
|
||||
},
|
||||
record_type="tool_execution"
|
||||
)
|
||||
|
||||
# Display the panel
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -138,5 +228,30 @@ def fuzzy_find_project_files(
|
|||
return filtered_matches
|
||||
|
||||
except FileListerError as e:
|
||||
console.print(f"[bold red]Error listing files: {e}[/bold red]")
|
||||
error_msg = f"Error listing files: {e}"
|
||||
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Fuzzy Find Error",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
|
|||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import key_snippets_formatter
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
|
@ -69,6 +70,22 @@ def emit_research_notes(notes: str) -> str:
|
|||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||
formatted_note = format_research_note(note_id, notes)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_research_notes",
|
||||
tool_parameters={"notes": notes},
|
||||
step_data={
|
||||
"note_id": note_id,
|
||||
"display_title": "Research Notes",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display formatted note
|
||||
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
|
||||
|
||||
|
|
@ -123,6 +140,23 @@ def emit_key_facts(facts: List[str]) -> str:
|
|||
console.print(f"Error storing fact: {str(e)}", style="red")
|
||||
continue
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_key_facts",
|
||||
tool_parameters={"facts": [fact]},
|
||||
step_data={
|
||||
"fact_id": fact_id,
|
||||
"fact": fact,
|
||||
"display_title": f"Key Fact #{fact_id}",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display panel with ID
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -214,6 +248,32 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
if snippet_info["description"]:
|
||||
display_text.extend(["", "**Description**:", snippet_info["description"]])
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_key_snippet",
|
||||
tool_parameters={
|
||||
"snippet_info": {
|
||||
"filepath": snippet_info["filepath"],
|
||||
"line_number": snippet_info["line_number"],
|
||||
"description": snippet_info["description"],
|
||||
# Omit the full snippet content to avoid duplicating large text in the database
|
||||
"snippet_length": len(snippet_info["snippet"])
|
||||
}
|
||||
},
|
||||
step_data={
|
||||
"snippet_id": snippet_id,
|
||||
"filepath": snippet_info["filepath"],
|
||||
"line_number": snippet_info["line_number"],
|
||||
"display_title": f"Key Snippet #{snippet_id}",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display panel
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -248,6 +308,25 @@ def one_shot_completed(message: str) -> str:
|
|||
message: Completion message to display
|
||||
"""
|
||||
mark_task_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="one_shot_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Task Completed",
|
||||
},
|
||||
record_type="task_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
|
@ -261,6 +340,25 @@ def task_completed(message: str) -> str:
|
|||
message: Message explaining how/why the task is complete
|
||||
"""
|
||||
mark_task_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="task_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Task Completed",
|
||||
},
|
||||
record_type="task_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
|
@ -275,6 +373,25 @@ def plan_implementation_completed(message: str) -> str:
|
|||
"""
|
||||
mark_should_exit(propagation_depth=1)
|
||||
mark_plan_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="plan_implementation_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Plan Executed",
|
||||
},
|
||||
record_type="plan_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||
log_work_event(f"Completed implementation:\n\n{message}")
|
||||
return "Plan completion noted."
|
||||
|
|
@ -361,10 +478,29 @@ def emit_related_files(files: List[str]) -> str:
|
|||
|
||||
results.append(f"File ID #{file_id}: {file}")
|
||||
|
||||
# Rich output - single consolidated panel for added files
|
||||
# Record to trajectory before displaying panel for added files
|
||||
if added_files:
|
||||
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
|
||||
md_content = f"**Files Noted:**\n{files_added_md}"
|
||||
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_related_files",
|
||||
tool_parameters={"files": files},
|
||||
step_data={
|
||||
"added_files": [file for _, file in added_files],
|
||||
"added_file_ids": [file_id for file_id, _ in added_files],
|
||||
"display_title": "Related Files Noted",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(md_content),
|
||||
|
|
@ -373,10 +509,28 @@ def emit_related_files(files: List[str]) -> str:
|
|||
)
|
||||
)
|
||||
|
||||
# Display skipped binary files
|
||||
# Record to trajectory before displaying panel for binary files
|
||||
if binary_files:
|
||||
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
||||
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
|
||||
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_related_files",
|
||||
tool_parameters={"files": files},
|
||||
step_data={
|
||||
"binary_files": binary_files,
|
||||
"display_title": "Binary Files Not Added",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(md_content),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import os.path
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
|
@ -16,6 +16,49 @@ console = Console()
|
|||
CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
def record_trajectory(
|
||||
tool_name: str,
|
||||
tool_parameters: Dict,
|
||||
step_data: Dict,
|
||||
record_type: str = "tool_execution",
|
||||
is_error: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_parameters: Parameters passed to the tool
|
||||
step_data: UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message
|
||||
error_type: The type/class of the error
|
||||
"""
|
||||
try:
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name=tool_name,
|
||||
tool_parameters=tool_parameters,
|
||||
step_data=step_data,
|
||||
record_type=record_type,
|
||||
human_input_id=human_input_id,
|
||||
is_error=is_error,
|
||||
error_message=error_message,
|
||||
error_type=error_type
|
||||
)
|
||||
except (ImportError, RuntimeError):
|
||||
# If either the repository modules can't be imported or no repository is available,
|
||||
# just log and continue without recording trajectory
|
||||
logging.debug("Skipping trajectory recording: repositories not available")
|
||||
|
||||
|
||||
@tool
|
||||
def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||
"""Read and return the contents of a text file.
|
||||
|
|
@ -29,10 +72,43 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
start_time = time.time()
|
||||
try:
|
||||
if not os.path.exists(filepath):
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Not Found",
|
||||
"error_message": f"File not found: {filepath}"
|
||||
},
|
||||
is_error=True,
|
||||
error_message=f"File not found: {filepath}",
|
||||
error_type="FileNotFoundError"
|
||||
)
|
||||
raise FileNotFoundError(f"File not found: {filepath}")
|
||||
|
||||
# Check if the file is binary
|
||||
if is_binary_file(filepath):
|
||||
# Record binary file error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "Binary File Detected",
|
||||
"error_message": f"Cannot read binary file: {filepath}"
|
||||
},
|
||||
is_error=True,
|
||||
error_message="Cannot read binary file",
|
||||
error_type="BinaryFileError"
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cannot read binary file: {filepath}",
|
||||
|
|
@ -67,6 +143,22 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s")
|
||||
logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines")
|
||||
|
||||
# Record successful file read in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Read",
|
||||
"line_count": line_count,
|
||||
"total_bytes": total_bytes,
|
||||
"elapsed_time": elapsed
|
||||
}
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s",
|
||||
|
|
@ -80,6 +172,25 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
|
||||
return {"content": truncated}
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Record exception in trajectory (if it's not already a handled FileNotFoundError)
|
||||
if not isinstance(e, FileNotFoundError):
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Read Error",
|
||||
"error_message": str(e)
|
||||
},
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ from langchain_core.tools import tool
|
|||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -10,6 +13,24 @@ def existing_project_detected() -> dict:
|
|||
"""
|
||||
When to call: Once you have confirmed that the current working directory contains project files.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="existing_project_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "existing_project",
|
||||
"display_title": "Existing Project Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("📁 Existing Project Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
@ -30,6 +51,24 @@ def monorepo_detected() -> dict:
|
|||
"""
|
||||
When to call: After identifying that multiple packages or modules exist within a single repository.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="monorepo_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "monorepo",
|
||||
"display_title": "Monorepo Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("📦 Monorepo Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
@ -53,6 +92,24 @@ def ui_detected() -> dict:
|
|||
"""
|
||||
When to call: After detecting that the project contains a user interface layer or front-end component.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ui_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "ui",
|
||||
"display_title": "UI Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.proc.interactive import run_interactive_command
|
||||
from ra_aid.text.processing import truncate_output
|
||||
|
||||
|
|
@ -158,6 +160,30 @@ def ripgrep_search(
|
|||
info_sections.append("\n".join(params))
|
||||
|
||||
# Execute command
|
||||
# Record ripgrep search in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters={
|
||||
"pattern": pattern,
|
||||
"before_context_lines": before_context_lines,
|
||||
"after_context_lines": after_context_lines,
|
||||
"file_type": file_type,
|
||||
"case_sensitive": case_sensitive,
|
||||
"include_hidden": include_hidden,
|
||||
"follow_links": follow_links,
|
||||
"exclude_dirs": exclude_dirs,
|
||||
"fixed_string": fixed_string
|
||||
},
|
||||
step_data={
|
||||
"search_pattern": pattern,
|
||||
"display_title": "Ripgrep Search",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(f"Searching for: **{pattern}**"),
|
||||
|
|
@ -179,5 +205,34 @@ def ripgrep_search(
|
|||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters={
|
||||
"pattern": pattern,
|
||||
"before_context_lines": before_context_lines,
|
||||
"after_context_lines": after_context_lines,
|
||||
"file_type": file_type,
|
||||
"case_sensitive": case_sensitive,
|
||||
"include_hidden": include_hidden,
|
||||
"follow_links": follow_links,
|
||||
"exclude_dirs": exclude_dirs,
|
||||
"fixed_string": fixed_string
|
||||
},
|
||||
step_data={
|
||||
"search_pattern": pattern,
|
||||
"display_title": "Ripgrep Search Error",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
console.print(Panel(error_msg, title="❌ Error", border_style="red"))
|
||||
return {"output": error_msg, "return_code": 1, "success": False}
|
||||
|
|
@ -10,6 +10,8 @@ from ra_aid.proc.interactive import run_interactive_command
|
|||
from ra_aid.text.processing import truncate_output
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -54,6 +56,20 @@ def run_shell_command(
|
|||
console.print(" " + get_cowboy_message())
|
||||
console.print("")
|
||||
|
||||
# Record tool execution in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="run_shell_command",
|
||||
tool_parameters={"command": command, "timeout": timeout},
|
||||
step_data={
|
||||
"command": command,
|
||||
"display_title": "Shell Command",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Show just the command in a simple panel
|
||||
console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow"))
|
||||
|
||||
|
|
@ -96,5 +112,23 @@ def run_shell_command(
|
|||
return result
|
||||
except Exception as e:
|
||||
print()
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="run_shell_command",
|
||||
tool_parameters={"command": command, "timeout": timeout},
|
||||
step_data={
|
||||
"command": command,
|
||||
"error": str(e),
|
||||
"display_title": "Shell Error",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__,
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
||||
return {"output": str(e), "return_code": 1, "success": False}
|
||||
|
|
@ -7,6 +7,9 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
from tavily import TavilyClient
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -21,9 +24,44 @@ def web_search_tavily(query: str) -> Dict:
|
|||
Returns:
|
||||
Dict containing search results from Tavily
|
||||
"""
|
||||
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
||||
# Record trajectory before displaying panel
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="web_search_tavily",
|
||||
tool_parameters={"query": query},
|
||||
step_data={
|
||||
"query": query,
|
||||
"display_title": "Web Search",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Display search query panel
|
||||
console.print(
|
||||
Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue")
|
||||
)
|
||||
search_result = client.search(query=query)
|
||||
return search_result
|
||||
|
||||
try:
|
||||
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
||||
search_result = client.search(query=query)
|
||||
return search_result
|
||||
except Exception as e:
|
||||
# Record error in trajectory
|
||||
trajectory_repo.create(
|
||||
tool_name="web_search_tavily",
|
||||
tool_parameters={"query": query},
|
||||
step_data={
|
||||
"query": query,
|
||||
"display_title": "Web Search Error",
|
||||
"error": str(e)
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
# Re-raise the exception to maintain original behavior
|
||||
raise
|
||||
|
|
@ -7,7 +7,7 @@ ensuring consistent test environments and proper isolation.
|
|||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -26,6 +26,39 @@ def mock_config_repository():
|
|||
yield repo
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests."""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.TrajectoryRepository') as mock:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.create.return_value = MagicMock(id=1)
|
||||
mock.return_value = mock_repo
|
||||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests."""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.HumanInputRepository') as mock:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
mock_repo.create.return_value = MagicMock(id=1)
|
||||
mock.return_value = mock_repo
|
||||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_repository_access(mock_trajectory_repository, mock_human_input_repository):
|
||||
"""Mock all repository accessor functions."""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.get_trajectory_repository',
|
||||
return_value=mock_trajectory_repository):
|
||||
with patch('ra_aid.database.repositories.human_input_repository.get_human_input_repository',
|
||||
return_value=mock_human_input_repository):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db_environment(tmp_path, monkeypatch):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -21,7 +21,11 @@ from ra_aid.anthropic_token_limiter import (
|
|||
state_modifier,
|
||||
)
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
|
||||
from ra_aid.database.repositories.config_repository import ConfigRepositoryManager, get_config_repository, config_repo_var
|
||||
from ra_aid.database.repositories.config_repository import (
|
||||
ConfigRepositoryManager,
|
||||
get_config_repository,
|
||||
config_repo_var,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -34,7 +38,9 @@ def mock_model():
|
|||
@pytest.fixture
|
||||
def mock_config_repository():
|
||||
"""Mock the ConfigRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.config_repository.config_repo_var') as mock_repo_var:
|
||||
with patch(
|
||||
"ra_aid.database.repositories.config_repository.config_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
|
|
@ -44,6 +50,7 @@ def mock_config_repository():
|
|||
# Setup get method to return config values
|
||||
def get_config(key, default=None):
|
||||
return config.get(key, default)
|
||||
|
||||
mock_repo.get.side_effect = get_config
|
||||
|
||||
# Setup get_all method to return all config values
|
||||
|
|
@ -52,11 +59,13 @@ def mock_config_repository():
|
|||
# Setup set method to update config values
|
||||
def set_config(key, value):
|
||||
config[key] = value
|
||||
|
||||
mock_repo.set.side_effect = set_config
|
||||
|
||||
# Setup update method to update multiple config values
|
||||
def update_config(update_dict):
|
||||
config.update(update_dict)
|
||||
|
||||
mock_repo.update.side_effect = update_config
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
|
|
@ -65,15 +74,55 @@ def mock_config_repository():
|
|||
yield mock_repo
|
||||
|
||||
|
||||
# These tests have been moved to test_anthropic_token_limiter.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch(
|
||||
"ra_aid.database.repositories.trajectory_repository.trajectory_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch(
|
||||
"ra_aid.database.repositories.human_input_repository.human_input_repo_var"
|
||||
) as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_create_agent_anthropic(mock_model, mock_config_repository):
|
||||
"""Test create_agent with Anthropic Claude model."""
|
||||
mock_config_repository.update({"provider": "anthropic", "model": "claude-2"})
|
||||
|
||||
with patch("ra_aid.agent_utils.create_react_agent") as mock_react, \
|
||||
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier:
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
patch("ra_aid.anthropic_token_limiter.state_modifier") as mock_state_modifier,
|
||||
):
|
||||
mock_react.return_value = "react_agent"
|
||||
agent = create_agent(mock_model, [])
|
||||
|
||||
|
|
@ -81,7 +130,7 @@ def test_create_agent_anthropic(mock_model, mock_config_repository):
|
|||
mock_react.assert_called_once_with(
|
||||
mock_model,
|
||||
[],
|
||||
interrupt_after=['tools'],
|
||||
interrupt_after=["tools"],
|
||||
version="v2",
|
||||
state_modifier=mock_react.call_args[1]["state_modifier"],
|
||||
name="React",
|
||||
|
|
@ -173,13 +222,17 @@ def test_create_agent_with_checkpointer(mock_model, mock_config_repository):
|
|||
)
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_repository):
|
||||
def test_create_agent_anthropic_token_limiting_enabled(
|
||||
mock_model, mock_config_repository
|
||||
):
|
||||
"""Test create_agent sets up token limiting for Claude models when enabled."""
|
||||
mock_config_repository.update({
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": True,
|
||||
})
|
||||
mock_config_repository.update(
|
||||
{
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": True,
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
|
|
@ -196,13 +249,17 @@ def test_create_agent_anthropic_token_limiting_enabled(mock_model, mock_config_r
|
|||
assert callable(args[1]["state_modifier"])
|
||||
|
||||
|
||||
def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_repository):
|
||||
def test_create_agent_anthropic_token_limiting_disabled(
|
||||
mock_model, mock_config_repository
|
||||
):
|
||||
"""Test create_agent doesn't set up token limiting for Claude models when disabled."""
|
||||
mock_config_repository.update({
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": False,
|
||||
})
|
||||
mock_config_repository.update(
|
||||
{
|
||||
"provider": "anthropic",
|
||||
"model": "claude-2",
|
||||
"limit_tokens": False,
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
patch("ra_aid.agent_utils.create_react_agent") as mock_react,
|
||||
|
|
@ -214,7 +271,9 @@ def test_create_agent_anthropic_token_limiting_disabled(mock_model, mock_config_
|
|||
agent = create_agent(mock_model, [])
|
||||
|
||||
assert agent == "react_agent"
|
||||
mock_react.assert_called_once_with(mock_model, [], interrupt_after=['tools'], version="v2", name="React")
|
||||
mock_react.assert_called_once_with(
|
||||
mock_model, [], interrupt_after=["tools"], version="v2", name="React"
|
||||
)
|
||||
|
||||
|
||||
# These tests have been moved to test_anthropic_token_limiter.py
|
||||
|
|
@ -482,7 +541,9 @@ def test_run_agent_with_retry_checks_crash_status(monkeypatch, mock_config_repos
|
|||
assert "Agent has crashed: Test crash message" in result
|
||||
|
||||
|
||||
def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_repository):
|
||||
def test_run_agent_with_retry_handles_badrequest_error(
|
||||
monkeypatch, mock_config_repository
|
||||
):
|
||||
"""Test that run_agent_with_retry properly handles BadRequestError as unretryable."""
|
||||
from ra_aid.agent_context import agent_context, is_crashed
|
||||
from ra_aid.agent_utils import run_agent_with_retry
|
||||
|
|
@ -540,7 +601,9 @@ def test_run_agent_with_retry_handles_badrequest_error(monkeypatch, mock_config_
|
|||
assert is_crashed()
|
||||
|
||||
|
||||
def test_run_agent_with_retry_handles_api_badrequest_error(monkeypatch, mock_config_repository):
|
||||
def test_run_agent_with_retry_handles_api_badrequest_error(
|
||||
monkeypatch, mock_config_repository
|
||||
):
|
||||
"""Test that run_agent_with_retry properly handles API BadRequestError as unretryable."""
|
||||
# Import APIError from anthropic module and patch it on the agent_utils module
|
||||
|
||||
|
|
@ -613,5 +676,7 @@ def test_handle_api_error_resource_exhausted():
|
|||
from ra_aid.agent_utils import _handle_api_error
|
||||
|
||||
# ResourceExhausted exception should be handled without raising
|
||||
resource_exhausted_error = ResourceExhausted("429 Resource has been exhausted (e.g. check quota).")
|
||||
resource_exhausted_error = ResourceExhausted(
|
||||
"429 Resource has been exhausted (e.g. check quota)."
|
||||
)
|
||||
_handle_api_error(resource_exhausted_error, 0, 5, 1)
|
||||
|
|
|
|||
|
|
@ -113,6 +113,40 @@ def mock_work_log_repository():
|
|||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture
|
||||
def mock_functions():
|
||||
"""Mock functions used in agent.py"""
|
||||
|
|
@ -126,7 +160,9 @@ def mock_functions():
|
|||
patch('ra_aid.tools.agent.get_related_files') as mock_get_files, \
|
||||
patch('ra_aid.tools.agent.get_work_log') as mock_get_work_log, \
|
||||
patch('ra_aid.tools.agent.reset_completion_flags') as mock_reset, \
|
||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion:
|
||||
patch('ra_aid.tools.agent.get_completion_message') as mock_get_completion, \
|
||||
patch('ra_aid.tools.agent.get_trajectory_repository') as mock_get_trajectory_repo, \
|
||||
patch('ra_aid.tools.agent.get_human_input_repository') as mock_get_human_input_repo:
|
||||
|
||||
# Setup mock return values
|
||||
mock_fact_repo.get_facts_dict.return_value = {1: "Test fact 1", 2: "Test fact 2"}
|
||||
|
|
@ -138,6 +174,15 @@ def mock_functions():
|
|||
mock_get_work_log.return_value = "Test work log"
|
||||
mock_get_completion.return_value = "Task completed"
|
||||
|
||||
# Setup mock for trajectory repository
|
||||
mock_trajectory_repo = MagicMock()
|
||||
mock_get_trajectory_repo.return_value = mock_trajectory_repo
|
||||
|
||||
# Setup mock for human input repository
|
||||
mock_human_input_repo = MagicMock()
|
||||
mock_human_input_repo.get_most_recent_id.return_value = 1
|
||||
mock_get_human_input_repo.return_value = mock_human_input_repo
|
||||
|
||||
# Return all mocks as a dictionary
|
||||
yield {
|
||||
'get_key_fact_repository': mock_get_fact_repo,
|
||||
|
|
@ -148,7 +193,9 @@ def mock_functions():
|
|||
'get_related_files': mock_get_files,
|
||||
'get_work_log': mock_get_work_log,
|
||||
'reset_completion_flags': mock_reset,
|
||||
'get_completion_message': mock_get_completion
|
||||
'get_completion_message': mock_get_completion,
|
||||
'get_trajectory_repository': mock_get_trajectory_repo,
|
||||
'get_human_input_repository': mock_get_human_input_repo
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,40 @@ def mock_config_repository():
|
|||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_trajectory_repository():
|
||||
"""Mock the TrajectoryRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.trajectory_repository.trajectory_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup create method to return a mock trajectory
|
||||
def mock_create(**kwargs):
|
||||
mock_trajectory = MagicMock()
|
||||
mock_trajectory.id = 1
|
||||
return mock_trajectory
|
||||
mock_repo.create.side_effect = mock_create
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_human_input_repository():
|
||||
"""Mock the HumanInputRepository to avoid database operations during tests"""
|
||||
with patch('ra_aid.database.repositories.human_input_repository.human_input_repo_var') as mock_repo_var:
|
||||
# Setup a mock repository
|
||||
mock_repo = MagicMock()
|
||||
|
||||
# Setup get_most_recent_id method to return a dummy ID
|
||||
mock_repo.get_most_recent_id.return_value = 1
|
||||
|
||||
# Make the mock context var return our mock repo
|
||||
mock_repo_var.get.return_value = mock_repo
|
||||
|
||||
yield mock_repo
|
||||
|
||||
|
||||
def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive, mock_config_repository):
|
||||
"""Test shell command execution in cowboy mode (no approval)"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
import { defineConfig } from '@vscode/test-cli';
|
||||
|
||||
export default defineConfig({
|
||||
files: 'out/test/**/*.test.js',
|
||||
});
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
// See http://go.microsoft.com/fwlink/?LinkId=827846
|
||||
// for the documentation about the extensions.json format
|
||||
"recommendations": ["dbaeumer.vscode-eslint", "connor4312.esbuild-problem-matchers", "ms-vscode.extension-test-runner"]
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
// A launch configuration that compiles the extension and then opens it inside a new window
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Run Extension",
|
||||
"type": "extensionHost",
|
||||
"request": "launch",
|
||||
"args": [
|
||||
"--extensionDevelopmentPath=${workspaceFolder}"
|
||||
],
|
||||
"outFiles": [
|
||||
"${workspaceFolder}/dist/**/*.js"
|
||||
],
|
||||
"preLaunchTask": "${defaultBuildTask}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
// Place your settings in this file to overwrite default and user settings.
|
||||
{
|
||||
"files.exclude": {
|
||||
"out": false, // set this to true to hide the "out" folder with the compiled JS files
|
||||
"dist": false // set this to true to hide the "dist" folder with the compiled JS files
|
||||
},
|
||||
"search.exclude": {
|
||||
"out": true, // set this to false to include "out" folder in search results
|
||||
"dist": true // set this to false to include "dist" folder in search results
|
||||
},
|
||||
// Turn off tsc task auto detection since we have the necessary tasks as npm scripts
|
||||
"typescript.tsc.autoDetect": "off"
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
// See https://go.microsoft.com/fwlink/?LinkId=733558
|
||||
// for the documentation about the tasks.json format
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"label": "watch",
|
||||
"dependsOn": [
|
||||
"npm: watch:tsc",
|
||||
"npm: watch:esbuild"
|
||||
],
|
||||
"presentation": {
|
||||
"reveal": "never"
|
||||
},
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "npm",
|
||||
"script": "watch:esbuild",
|
||||
"group": "build",
|
||||
"problemMatcher": "$esbuild-watch",
|
||||
"isBackground": true,
|
||||
"label": "npm: watch:esbuild",
|
||||
"presentation": {
|
||||
"group": "watch",
|
||||
"reveal": "never"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "npm",
|
||||
"script": "watch:tsc",
|
||||
"group": "build",
|
||||
"problemMatcher": "$tsc-watch",
|
||||
"isBackground": true,
|
||||
"label": "npm: watch:tsc",
|
||||
"presentation": {
|
||||
"group": "watch",
|
||||
"reveal": "never"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "npm",
|
||||
"script": "watch-tests",
|
||||
"problemMatcher": "$tsc-watch",
|
||||
"isBackground": true,
|
||||
"presentation": {
|
||||
"reveal": "never",
|
||||
"group": "watchers"
|
||||
},
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"label": "tasks: watch-tests",
|
||||
"dependsOn": [
|
||||
"npm: watch",
|
||||
"npm: watch-tests"
|
||||
],
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
.vscode/**
|
||||
.vscode-test/**
|
||||
out/**
|
||||
node_modules/**
|
||||
src/**
|
||||
.gitignore
|
||||
.yarnrc
|
||||
esbuild.js
|
||||
vsc-extension-quickstart.md
|
||||
**/tsconfig.json
|
||||
**/eslint.config.mjs
|
||||
**/*.map
|
||||
**/*.ts
|
||||
**/.vscode-test.*
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# Change Log
|
||||
|
||||
All notable changes to the "ra-aid" extension will be documented in this file.
|
||||
|
||||
Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file.
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
- Initial release
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# ra-aid README
|
||||
|
||||
This is the README for your extension "ra-aid". After writing up a brief description, we recommend including the following sections.
|
||||
|
||||
## Features
|
||||
|
||||
Describe specific features of your extension including screenshots of your extension in action. Image paths are relative to this README file.
|
||||
|
||||
For example if there is an image subfolder under your extension project workspace:
|
||||
|
||||
\!\[feature X\]\(images/feature-x.png\)
|
||||
|
||||
> Tip: Many popular extensions utilize animations. This is an excellent way to show off your extension! We recommend short, focused animations that are easy to follow.
|
||||
|
||||
## Requirements
|
||||
|
||||
If you have any requirements or dependencies, add a section describing those and how to install and configure them.
|
||||
|
||||
## Extension Settings
|
||||
|
||||
Include if your extension adds any VS Code settings through the `contributes.configuration` extension point.
|
||||
|
||||
For example:
|
||||
|
||||
This extension contributes the following settings:
|
||||
|
||||
* `myExtension.enable`: Enable/disable this extension.
|
||||
* `myExtension.thing`: Set to `blah` to do something.
|
||||
|
||||
## Known Issues
|
||||
|
||||
Calling out known issues can help limit users opening duplicate issues against your extension.
|
||||
|
||||
## Release Notes
|
||||
|
||||
Users appreciate release notes as you update your extension.
|
||||
|
||||
### 1.0.0
|
||||
|
||||
Initial release of ...
|
||||
|
||||
### 1.0.1
|
||||
|
||||
Fixed issue #.
|
||||
|
||||
### 1.1.0
|
||||
|
||||
Added features X, Y, and Z.
|
||||
|
||||
---
|
||||
|
||||
## Following extension guidelines
|
||||
|
||||
Ensure that you've read through the extensions guidelines and follow the best practices for creating your extension.
|
||||
|
||||
* [Extension Guidelines](https://code.visualstudio.com/api/references/extension-guidelines)
|
||||
|
||||
## Working with Markdown
|
||||
|
||||
You can author your README using Visual Studio Code. Here are some useful editor keyboard shortcuts:
|
||||
|
||||
* Split the editor (`Cmd+\` on macOS or `Ctrl+\` on Windows and Linux).
|
||||
* Toggle preview (`Shift+Cmd+V` on macOS or `Shift+Ctrl+V` on Windows and Linux).
|
||||
* Press `Ctrl+Space` (Windows, Linux, macOS) to see a list of Markdown snippets.
|
||||
|
||||
## For more information
|
||||
|
||||
* [Visual Studio Code's Markdown Support](http://code.visualstudio.com/docs/languages/markdown)
|
||||
* [Markdown Syntax Reference](https://help.github.com/articles/markdown-basics/)
|
||||
|
||||
**Enjoy!**
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 6.5 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 6.6 KiB |
|
|
@ -0,0 +1,56 @@
|
|||
const esbuild = require("esbuild");
|
||||
|
||||
const production = process.argv.includes('--production');
|
||||
const watch = process.argv.includes('--watch');
|
||||
|
||||
/**
|
||||
* @type {import('esbuild').Plugin}
|
||||
*/
|
||||
const esbuildProblemMatcherPlugin = {
|
||||
name: 'esbuild-problem-matcher',
|
||||
|
||||
setup(build) {
|
||||
build.onStart(() => {
|
||||
console.log('[watch] build started');
|
||||
});
|
||||
build.onEnd((result) => {
|
||||
result.errors.forEach(({ text, location }) => {
|
||||
console.error(`✘ [ERROR] ${text}`);
|
||||
console.error(` ${location.file}:${location.line}:${location.column}:`);
|
||||
});
|
||||
console.log('[watch] build finished');
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
async function main() {
|
||||
const ctx = await esbuild.context({
|
||||
entryPoints: [
|
||||
'src/extension.ts'
|
||||
],
|
||||
bundle: true,
|
||||
format: 'cjs',
|
||||
minify: production,
|
||||
sourcemap: !production,
|
||||
sourcesContent: false,
|
||||
platform: 'node',
|
||||
outfile: 'dist/extension.js',
|
||||
external: ['vscode'],
|
||||
logLevel: 'silent',
|
||||
plugins: [
|
||||
/* add to the end of plugins array */
|
||||
esbuildProblemMatcherPlugin,
|
||||
],
|
||||
});
|
||||
if (watch) {
|
||||
await ctx.watch();
|
||||
} else {
|
||||
await ctx.rebuild();
|
||||
await ctx.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
main().catch(e => {
|
||||
console.error(e);
|
||||
process.exit(1);
|
||||
});
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
import typescriptEslint from "@typescript-eslint/eslint-plugin";
|
||||
import tsParser from "@typescript-eslint/parser";
|
||||
|
||||
export default [{
|
||||
files: ["**/*.ts"],
|
||||
}, {
|
||||
plugins: {
|
||||
"@typescript-eslint": typescriptEslint,
|
||||
},
|
||||
|
||||
languageOptions: {
|
||||
parser: tsParser,
|
||||
ecmaVersion: 2022,
|
||||
sourceType: "module",
|
||||
},
|
||||
|
||||
rules: {
|
||||
"@typescript-eslint/naming-convention": ["warn", {
|
||||
selector: "import",
|
||||
format: ["camelCase", "PascalCase"],
|
||||
}],
|
||||
|
||||
curly: "warn",
|
||||
eqeqeq: "warn",
|
||||
"no-throw-literal": "warn",
|
||||
semi: "warn",
|
||||
},
|
||||
}];
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,67 @@
|
|||
{
|
||||
"name": "ra-aid",
|
||||
"displayName": "RA.Aid",
|
||||
"description": "Develop software autonomously.",
|
||||
"version": "0.0.1",
|
||||
"engines": {
|
||||
"vscode": "^1.98.0"
|
||||
},
|
||||
"categories": [
|
||||
"Other"
|
||||
],
|
||||
"activationEvents": [],
|
||||
"main": "./dist/extension.js",
|
||||
"contributes": {
|
||||
"viewsContainers": {
|
||||
"activitybar": [
|
||||
{
|
||||
"id": "ra-aid-view",
|
||||
"title": "RA.Aid",
|
||||
"icon": "assets/RA-white-transp.png"
|
||||
}
|
||||
]
|
||||
},
|
||||
"views": {
|
||||
"ra-aid-view": [
|
||||
{
|
||||
"type": "webview",
|
||||
"id": "ra-aid.view",
|
||||
"name": "RA.Aid"
|
||||
}
|
||||
]
|
||||
},
|
||||
"commands": [
|
||||
{
|
||||
"command": "ra-aid.helloWorld",
|
||||
"title": "Hello World"
|
||||
}
|
||||
]
|
||||
},
|
||||
"scripts": {
|
||||
"vscode:prepublish": "npm run package",
|
||||
"compile": "npm run check-types && npm run lint && node esbuild.js",
|
||||
"watch": "npm-run-all -p watch:*",
|
||||
"watch:esbuild": "node esbuild.js --watch",
|
||||
"watch:tsc": "tsc --noEmit --watch --project tsconfig.json",
|
||||
"package": "npm run check-types && npm run lint && node esbuild.js --production",
|
||||
"compile-tests": "tsc -p . --outDir out",
|
||||
"watch-tests": "tsc -p . -w --outDir out",
|
||||
"pretest": "npm run compile-tests && npm run compile && npm run lint",
|
||||
"check-types": "tsc --noEmit",
|
||||
"lint": "eslint src",
|
||||
"test": "vscode-test"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/vscode": "^1.98.0",
|
||||
"@types/mocha": "^10.0.10",
|
||||
"@types/node": "20.x",
|
||||
"@typescript-eslint/eslint-plugin": "^8.25.0",
|
||||
"@typescript-eslint/parser": "^8.25.0",
|
||||
"eslint": "^9.21.0",
|
||||
"esbuild": "^0.25.0",
|
||||
"npm-run-all": "^4.1.5",
|
||||
"typescript": "^5.7.3",
|
||||
"@vscode/test-cli": "^0.0.10",
|
||||
"@vscode/test-electron": "^2.4.1"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
// The module 'vscode' contains the VS Code extensibility API
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
/**
|
||||
* WebviewViewProvider implementation for the RA.Aid panel
|
||||
*/
|
||||
class RAWebviewViewProvider implements vscode.WebviewViewProvider {
|
||||
constructor(private readonly _extensionUri: vscode.Uri) {}
|
||||
|
||||
/**
|
||||
* Called when a view is first created to initialize the webview
|
||||
*/
|
||||
public resolveWebviewView(
|
||||
webviewView: vscode.WebviewView,
|
||||
context: vscode.WebviewViewResolveContext,
|
||||
_token: vscode.CancellationToken
|
||||
) {
|
||||
// Set options for the webview
|
||||
webviewView.webview.options = {
|
||||
// Enable JavaScript in the webview
|
||||
enableScripts: true,
|
||||
// Restrict the webview to only load resources from the extension's directory
|
||||
localResourceRoots: [this._extensionUri]
|
||||
};
|
||||
|
||||
// Set the HTML content of the webview
|
||||
webviewView.webview.html = this._getHtmlForWebview(webviewView.webview);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates HTML content for the webview with proper security policies
|
||||
*/
|
||||
private _getHtmlForWebview(webview: vscode.Webview): string {
|
||||
// Create a URI to the extension's assets directory
|
||||
const logoUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'assets', 'RA.png'));
|
||||
|
||||
// Create a URI to the script file
|
||||
// const scriptUri = webview.asWebviewUri(vscode.Uri.joinPath(this._extensionUri, 'dist', 'webview.js'));
|
||||
|
||||
// Use a nonce to whitelist scripts
|
||||
const nonce = getNonce();
|
||||
|
||||
return `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta http-equiv="Content-Security-Policy" content="default-src 'none'; img-src ${webview.cspSource} https:; style-src ${webview.cspSource} 'unsafe-inline'; script-src 'nonce-${nonce}';">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>RA.Aid</title>
|
||||
<style>
|
||||
body {
|
||||
padding: 0;
|
||||
color: var(--vscode-foreground);
|
||||
font-size: var(--vscode-font-size);
|
||||
font-weight: var(--vscode-font-weight);
|
||||
font-family: var(--vscode-font-family);
|
||||
background-color: var(--vscode-editor-background);
|
||||
}
|
||||
.container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
.logo {
|
||||
width: 100px;
|
||||
height: 100px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
h1 {
|
||||
color: var(--vscode-editor-foreground);
|
||||
font-size: 1.3em;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
p {
|
||||
color: var(--vscode-foreground);
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<img src="${logoUri}" alt="RA.Aid Logo" class="logo">
|
||||
<h1>RA.Aid</h1>
|
||||
<p>Your research and development assistant.</p>
|
||||
<p>More features coming soon!</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a random nonce for CSP
|
||||
*/
|
||||
function getNonce() {
|
||||
let text = '';
|
||||
const possible = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';
|
||||
for (let i = 0; i < 32; i++) {
|
||||
text += possible.charAt(Math.floor(Math.random() * possible.length));
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
// This method is called when your extension is activated
|
||||
export function activate(context: vscode.ExtensionContext) {
|
||||
// Use the console to output diagnostic information (console.log) and errors (console.error)
|
||||
console.log('Congratulations, your extension "ra-aid" is now active!');
|
||||
|
||||
// Register the WebviewViewProvider
|
||||
const provider = new RAWebviewViewProvider(context.extensionUri);
|
||||
const viewRegistration = vscode.window.registerWebviewViewProvider(
|
||||
'ra-aid.view', // Must match the view id in package.json
|
||||
provider
|
||||
);
|
||||
context.subscriptions.push(viewRegistration);
|
||||
|
||||
// The command has been defined in the package.json file
|
||||
// Now provide the implementation of the command with registerCommand
|
||||
// The commandId parameter must match the command field in package.json
|
||||
const disposable = vscode.commands.registerCommand('ra-aid.helloWorld', () => {
|
||||
// The code you place here will be executed every time your command is executed
|
||||
// Display a message box to the user
|
||||
vscode.window.showInformationMessage('Hello World from RA.Aid!');
|
||||
});
|
||||
|
||||
context.subscriptions.push(disposable);
|
||||
}
|
||||
|
||||
// This method is called when your extension is deactivated
|
||||
export function deactivate() {}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import * as assert from 'assert';
|
||||
|
||||
// You can import and use all API from the 'vscode' module
|
||||
// as well as import your extension to test it
|
||||
import * as vscode from 'vscode';
|
||||
// import * as myExtension from '../../extension';
|
||||
|
||||
suite('Extension Test Suite', () => {
|
||||
vscode.window.showInformationMessage('Start all tests.');
|
||||
|
||||
test('Sample test', () => {
|
||||
assert.strictEqual(-1, [1, 2, 3].indexOf(5));
|
||||
assert.strictEqual(-1, [1, 2, 3].indexOf(0));
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"module": "Node16",
|
||||
"target": "ES2022",
|
||||
"lib": [
|
||||
"ES2022"
|
||||
],
|
||||
"sourceMap": true,
|
||||
"rootDir": "src",
|
||||
"strict": true, /* enable all strict type-checking options */
|
||||
/* Additional Checks */
|
||||
// "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */
|
||||
// "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */
|
||||
// "noUnusedParameters": true, /* Report errors on unused parameters. */
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# Welcome to your VS Code Extension
|
||||
|
||||
## What's in the folder
|
||||
|
||||
* This folder contains all of the files necessary for your extension.
|
||||
* `package.json` - this is the manifest file in which you declare your extension and command.
|
||||
* The sample plugin registers a command and defines its title and command name. With this information VS Code can show the command in the command palette. It doesn’t yet need to load the plugin.
|
||||
* `src/extension.ts` - this is the main file where you will provide the implementation of your command.
|
||||
* The file exports one function, `activate`, which is called the very first time your extension is activated (in this case by executing the command). Inside the `activate` function we call `registerCommand`.
|
||||
* We pass the function containing the implementation of the command as the second parameter to `registerCommand`.
|
||||
|
||||
## Setup
|
||||
|
||||
* install the recommended extensions (amodio.tsl-problem-matcher, ms-vscode.extension-test-runner, and dbaeumer.vscode-eslint)
|
||||
|
||||
|
||||
## Get up and running straight away
|
||||
|
||||
* Press `F5` to open a new window with your extension loaded.
|
||||
* Run your command from the command palette by pressing (`Ctrl+Shift+P` or `Cmd+Shift+P` on Mac) and typing `Hello World`.
|
||||
* Set breakpoints in your code inside `src/extension.ts` to debug your extension.
|
||||
* Find output from your extension in the debug console.
|
||||
|
||||
## Make changes
|
||||
|
||||
* You can relaunch the extension from the debug toolbar after changing code in `src/extension.ts`.
|
||||
* You can also reload (`Ctrl+R` or `Cmd+R` on Mac) the VS Code window with your extension to load your changes.
|
||||
|
||||
|
||||
## Explore the API
|
||||
|
||||
* You can open the full set of our API when you open the file `node_modules/@types/vscode/index.d.ts`.
|
||||
|
||||
## Run tests
|
||||
|
||||
* Install the [Extension Test Runner](https://marketplace.visualstudio.com/items?itemName=ms-vscode.extension-test-runner)
|
||||
* Run the "watch" task via the **Tasks: Run Task** command. Make sure this is running, or tests might not be discovered.
|
||||
* Open the Testing view from the activity bar and click the Run Test" button, or use the hotkey `Ctrl/Cmd + ; A`
|
||||
* See the output of the test result in the Test Results view.
|
||||
* Make changes to `src/test/extension.test.ts` or create new test files inside the `test` folder.
|
||||
* The provided test runner will only consider files matching the name pattern `**.test.ts`.
|
||||
* You can create folders inside the `test` folder to structure your tests any way you want.
|
||||
|
||||
## Go further
|
||||
|
||||
* Reduce the extension size and improve the startup time by [bundling your extension](https://code.visualstudio.com/api/working-with-extensions/bundling-extension).
|
||||
* [Publish your extension](https://code.visualstudio.com/api/working-with-extensions/publishing-extension) on the VS Code extension marketplace.
|
||||
* Automate builds by setting up [Continuous Integration](https://code.visualstudio.com/api/working-with-extensions/continuous-integration).
|
||||
Loading…
Reference in New Issue