record trajectory at all steps
This commit is contained in:
parent
5d899d3d13
commit
ae9cf5021b
|
|
@ -616,6 +616,24 @@ def main():
|
|||
)
|
||||
|
||||
if args.research_only:
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
error_message = "Chat mode cannot be used with --research-only"
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"display_title": "Error",
|
||||
"error_message": error_message,
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as traj_error:
|
||||
# Swallow exception to avoid recursion
|
||||
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||
pass
|
||||
print_error("Chat mode cannot be used with --research-only")
|
||||
sys.exit(1)
|
||||
|
||||
|
|
@ -719,6 +737,24 @@ def main():
|
|||
|
||||
# Validate message is provided
|
||||
if not args.message:
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
error_message = "--message is required"
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"display_title": "Error",
|
||||
"error_message": error_message,
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as traj_error:
|
||||
# Swallow exception to avoid recursion
|
||||
logger.debug(f"Error recording trajectory: {traj_error}")
|
||||
pass
|
||||
print_error("--message is required")
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -462,6 +462,38 @@ class CiaynAgent:
|
|||
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
|
||||
tool_name = self.extract_tool_name(code)
|
||||
logger.info(f"Tool execution failed for `{tool_name}`: {str(e)}")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": f"Tool execution failed for `{tool_name}`:\nError: {str(e)}",
|
||||
"display_title": "Tool Error",
|
||||
"code": code,
|
||||
"tool_name": tool_name
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="ToolExecutionError",
|
||||
tool_name=tool_name,
|
||||
tool_parameters={"code": code}
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for tool error display: {trajectory_error}")
|
||||
|
||||
print_warning(f"Tool execution failed for `{tool_name}`:\nError: {str(e)}\n\nCode:\n\n````\n{code}\n````", title="Tool Error")
|
||||
raise ToolExecutionError(
|
||||
error_msg, base_message=msg, tool_name=tool_name
|
||||
|
|
@ -495,6 +527,36 @@ class CiaynAgent:
|
|||
if not fallback_response:
|
||||
self.chat_history.append(err_msg)
|
||||
logger.info(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": f"Tool fallback was attempted but did not succeed. Original error: {str(e)}",
|
||||
"display_title": "Fallback Failed",
|
||||
"tool_name": e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="FallbackFailedError",
|
||||
tool_name=e.tool_name if hasattr(e, "tool_name") else "unknown_tool"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for fallback failed warning: {trajectory_error}")
|
||||
|
||||
print_warning(f"Tool fallback was attempted but did not succeed. Original error: {str(e)}", title="Fallback Failed")
|
||||
return ""
|
||||
|
||||
|
|
@ -595,6 +657,35 @@ class CiaynAgent:
|
|||
matches = re.findall(pattern, response, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
logger.info("Failed to extract a valid tool call from the model's response.")
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": "Failed to extract a valid tool call from the model's response.",
|
||||
"display_title": "Extraction Failed",
|
||||
"code": code
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message="Failed to extract a valid tool call from the model's response.",
|
||||
error_type="ExtractionError"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for extraction error display: {trajectory_error}")
|
||||
|
||||
print_warning("Failed to extract a valid tool call from the model's response.", title="Extraction Failed")
|
||||
raise ToolExecutionError("Failed to extract tool call")
|
||||
ma = matches[0][0].strip()
|
||||
|
|
@ -647,6 +738,36 @@ class CiaynAgent:
|
|||
|
||||
warning_message = f"The model returned an empty response (attempt {empty_response_count} of {max_empty_responses}). Requesting the model to make a valid tool call."
|
||||
logger.info(warning_message)
|
||||
|
||||
# Record warning in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db_connection
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||
human_input_repo = HumanInputRepository(get_db_connection())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"warning_message": warning_message,
|
||||
"display_title": "Empty Response",
|
||||
"attempt": empty_response_count,
|
||||
"max_attempts": max_empty_responses
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=warning_message,
|
||||
error_type="EmptyResponseWarning"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for empty response warning: {trajectory_error}")
|
||||
|
||||
print_warning(warning_message, title="Empty Response")
|
||||
|
||||
if empty_response_count >= max_empty_responses:
|
||||
|
|
@ -658,6 +779,36 @@ class CiaynAgent:
|
|||
|
||||
error_message = "The agent has crashed after multiple failed attempts to generate a valid tool call."
|
||||
logger.error(error_message)
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db_connection
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db_connection())
|
||||
human_input_repo = HumanInputRepository(get_db_connection())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Agent Crashed",
|
||||
"crash_reason": crash_message,
|
||||
"attempts": empty_response_count
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message,
|
||||
error_type="AgentCrashError"
|
||||
)
|
||||
except Exception as trajectory_error:
|
||||
# Just log and continue if there's an error in trajectory recording
|
||||
logger.error(f"Error recording trajectory for agent crash: {trajectory_error}")
|
||||
|
||||
print_error(error_message)
|
||||
|
||||
yield self._create_error_chunk(crash_message)
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ from ra_aid.database.repositories.human_input_repository import (
|
|||
from ra_aid.database.repositories.research_note_repository import (
|
||||
get_research_note_repository,
|
||||
)
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import format_key_facts_dict
|
||||
from ra_aid.model_formatters.key_snippets_formatter import format_key_snippets_dict
|
||||
|
|
@ -460,9 +461,23 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
|
||||
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
|
||||
delay = base_delay * (2**attempt)
|
||||
print_error(
|
||||
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
error_message = f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < delay:
|
||||
check_interrupt()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from ra_aid import agent_utils
|
|||
from ra_aid.database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -82,6 +83,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
if deleted_facts:
|
||||
deleted_msg = "Successfully deleted facts:\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in deleted_facts])
|
||||
result_parts.append(deleted_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_facts": deleted_facts,
|
||||
"display_title": "Facts Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
|
||||
)
|
||||
|
|
@ -89,6 +106,22 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
if protected_facts:
|
||||
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_facts": protected_facts,
|
||||
"display_title": "Facts Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -120,10 +153,44 @@ def run_key_facts_gc_agent() -> None:
|
|||
fact_count = len(facts)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access key fact repository: {str(e)}")
|
||||
# Record GC error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error": str(e),
|
||||
"display_title": "GC Error",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="Repository Error"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||
return # Exit the function if we can't access the repository
|
||||
|
||||
# Display status panel with fact count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"fact_count": fact_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key facts: {fact_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have facts to clean
|
||||
|
|
@ -185,6 +252,24 @@ def run_key_facts_gc_agent() -> None:
|
|||
# Show info panel with updated count and protected facts count
|
||||
protected_count = len(protected_facts)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": fact_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}\nProtected facts (associated with current request): {protected_count}",
|
||||
|
|
@ -192,6 +277,24 @@ def run_key_facts_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": fact_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}",
|
||||
|
|
@ -199,6 +302,40 @@ def run_key_facts_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_facts),
|
||||
"message": "All facts are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"fact_count": 0,
|
||||
"message": "No key facts to clean",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_facts_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -18,6 +18,7 @@ from ra_aid import agent_utils
|
|||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -65,6 +66,23 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
success = get_key_snippet_repository().delete(snippet_id)
|
||||
if success:
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_snippet_id": snippet_id,
|
||||
"filepath": filepath,
|
||||
"display_title": "Snippet Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(success_msg), title="Snippet Deleted", border_style="green"
|
||||
|
|
@ -86,6 +104,22 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
if protected_snippets:
|
||||
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_snippets": protected_snippets,
|
||||
"display_title": "Snippets Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -116,6 +150,21 @@ def run_key_snippets_gc_agent() -> None:
|
|||
snippet_count = len(snippets)
|
||||
|
||||
# Display status panel with snippet count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"snippet_count": snippet_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of key snippets: {snippet_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have snippets to clean
|
||||
|
|
@ -185,6 +234,24 @@ def run_key_snippets_gc_agent() -> None:
|
|||
# Show info panel with updated count and protected snippets count
|
||||
protected_count = len(protected_snippets)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": snippet_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}\nProtected snippets (associated with current request): {protected_count}",
|
||||
|
|
@ -192,6 +259,24 @@ def run_key_snippets_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": snippet_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
||||
|
|
@ -199,6 +284,40 @@ def run_key_snippets_gc_agent() -> None:
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_snippets),
|
||||
"message": "All snippets are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"snippet_count": 0,
|
||||
"message": "No key snippets to clean",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="key_snippets_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -22,6 +22,7 @@ from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
|||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
|
|
@ -84,6 +85,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
|||
if deleted_notes:
|
||||
deleted_msg = "Successfully deleted research notes:\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in deleted_notes])
|
||||
result_parts.append(deleted_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"deleted_notes": deleted_notes,
|
||||
"display_title": "Research Notes Deleted",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(deleted_msg), title="Research Notes Deleted", border_style="green")
|
||||
)
|
||||
|
|
@ -91,6 +108,22 @@ def delete_research_notes(note_ids: List[int]) -> str:
|
|||
if protected_notes:
|
||||
protected_msg = "Protected research notes (associated with current request):\n" + "\n".join([f"- #{note_id}: {content[:100]}..." if len(content) > 100 else f"- #{note_id}: {content}" for note_id, content in protected_notes])
|
||||
result_parts.append(protected_msg)
|
||||
# Record GC operation in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_notes": protected_notes,
|
||||
"display_title": "Research Notes Protected",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Research Notes Protected", border_style="blue")
|
||||
)
|
||||
|
|
@ -125,10 +158,44 @@ def run_research_notes_gc_agent(threshold: int = 30) -> None:
|
|||
note_count = len(notes)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Failed to access research note repository: {str(e)}")
|
||||
# Record GC error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error": str(e),
|
||||
"display_title": "GC Error",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type="Repository Error"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Error: {str(e)}", title="🗑 GC Error", border_style="red"))
|
||||
return # Exit the function if we can't access the repository
|
||||
|
||||
# Display status panel with note count included
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"note_count": note_count,
|
||||
"display_title": "Garbage Collection",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Gathering my thoughts...\nCurrent number of research notes: {note_count}", title="🗑 Garbage Collection"))
|
||||
|
||||
# Only run the agent if we actually have notes to clean and we're over the threshold
|
||||
|
|
@ -235,6 +302,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
# Show info panel with updated count and protected notes count
|
||||
protected_count = len(protected_notes)
|
||||
if protected_count > 0:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": note_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": protected_count,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned research notes: {note_count} → {updated_count}\nProtected notes (associated with current request): {protected_count}",
|
||||
|
|
@ -242,6 +327,24 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC completion in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"original_count": note_count,
|
||||
"updated_count": updated_count,
|
||||
"protected_count": 0,
|
||||
"display_title": "GC Complete",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned research notes: {note_count} → {updated_count}",
|
||||
|
|
@ -249,6 +352,41 @@ Remember: Your goal is to maintain a concise, high-value collection of research
|
|||
)
|
||||
)
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"protected_count": len(protected_notes),
|
||||
"message": "All research notes are protected",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"All {len(protected_notes)} research notes are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
# Record GC info in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"note_count": note_count,
|
||||
"threshold": threshold,
|
||||
"message": "Below threshold - no cleanup needed",
|
||||
"display_title": "GC Info",
|
||||
},
|
||||
record_type="gc_operation",
|
||||
human_input_id=human_input_id,
|
||||
tool_name="research_notes_gc_agent"
|
||||
)
|
||||
except Exception:
|
||||
pass # Continue if trajectory recording fails
|
||||
|
||||
console.print(Panel(f"Research notes count ({note_count}) is below threshold ({threshold}). No cleanup needed.", title="🗑 GC Info"))
|
||||
|
|
@ -154,6 +154,24 @@ class FallbackHandler:
|
|||
logger.debug(
|
||||
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
|
||||
)
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
"display_title": "Fallback Notification",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm(
|
||||
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
|
||||
title="Fallback Notification",
|
||||
|
|
@ -163,6 +181,24 @@ class FallbackHandler:
|
|||
if result_list:
|
||||
return result_list
|
||||
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "All fallback models have failed.",
|
||||
"display_title": "Fallback Failed",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm("All fallback models have failed.", title="Fallback Failed")
|
||||
|
||||
current_failing_tool_name = self.current_failing_tool_name
|
||||
|
|
|
|||
|
|
@ -234,6 +234,24 @@ def create_llm_client(
|
|||
elif supports_temperature:
|
||||
if temperature is None:
|
||||
temperature = 0.7
|
||||
# Import repository classes directly to avoid circular imports
|
||||
from ra_aid.database.repositories.trajectory_repository import TrajectoryRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.database.connection import get_db
|
||||
|
||||
# Create repositories directly
|
||||
trajectory_repo = TrajectoryRepository(get_db())
|
||||
human_input_repo = HumanInputRepository(get_db())
|
||||
human_input_id = human_input_repo.get_most_recent_id()
|
||||
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"message": "This model supports temperature argument but none was given. Setting default temperature to 0.7.",
|
||||
"display_title": "Information",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
cpm(
|
||||
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ __all__ = [
|
|||
|
||||
from ra_aid.file_listing import FileListerError, get_file_listing
|
||||
from ra_aid.project_state import ProjectStateError, is_new_project
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -130,6 +132,24 @@ def display_project_status(info: ProjectInfo) -> None:
|
|||
{status} with **{file_count} file(s)**
|
||||
"""
|
||||
|
||||
# Record project status in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"project_status": "new" if info.is_new else "existing",
|
||||
"file_count": file_count,
|
||||
"total_files": info.total_files,
|
||||
"display_title": "Project Status",
|
||||
},
|
||||
record_type="info",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Silently continue if trajectory recording fails
|
||||
pass
|
||||
|
||||
# Create and display panel
|
||||
console = Console()
|
||||
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))
|
||||
console.print(Panel(Markdown(status_text.strip()), title="📊 Project Status"))
|
||||
|
|
@ -62,7 +62,23 @@ def request_research(query: str) -> ResearchResult:
|
|||
# Check recursion depth
|
||||
current_depth = get_depth()
|
||||
if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT:
|
||||
print_error("Maximum research recursion depth reached")
|
||||
error_message = "Maximum research recursion depth reached"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
try:
|
||||
key_facts = format_key_facts_dict(get_key_fact_repository().get_facts_dict())
|
||||
except RuntimeError as e:
|
||||
|
|
@ -109,7 +125,23 @@ def request_research(query: str) -> ResearchResult:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during research: {str(e)}")
|
||||
error_message = f"Error during research: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
|
|
@ -194,7 +226,23 @@ def request_web_research(query: str) -> ResearchResult:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during web research: {str(e)}")
|
||||
error_message = f"Error during web research: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
finally:
|
||||
|
|
@ -384,7 +432,23 @@ def request_task_implementation(task_spec: str) -> str:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during task implementation: {str(e)}")
|
||||
error_message = f"Error during task implementation: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
|
||||
|
|
@ -515,7 +579,23 @@ def request_implementation(task_spec: str) -> str:
|
|||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
print_error(f"Error during planning: {str(e)}")
|
||||
error_message = f"Error during planning: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_message,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
print_error(error_message)
|
||||
success = False
|
||||
reason = f"error: {str(e)}"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ from rich.panel import Panel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ..database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ..database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
from ..database.repositories.key_fact_repository import get_key_fact_repository
|
||||
from ..database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ..database.repositories.related_files_repository import get_related_files_repository
|
||||
|
|
@ -72,6 +75,23 @@ def emit_expert_context(context: str) -> str:
|
|||
"""
|
||||
expert_context["text"].append(context)
|
||||
|
||||
# Record expert context in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_expert_context",
|
||||
tool_parameters={"context_length": len(context)},
|
||||
step_data={
|
||||
"display_title": "Expert Context",
|
||||
"context_length": len(context),
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Create and display status panel
|
||||
panel_content = f"Added expert context ({len(context)} characters)"
|
||||
console.print(Panel(panel_content, title="Expert Context", border_style="blue"))
|
||||
|
|
@ -184,6 +204,23 @@ def ask_expert(question: str) -> str:
|
|||
# Build display query (just question)
|
||||
display_query = "# Question\n" + question
|
||||
|
||||
# Record expert query in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ask_expert",
|
||||
tool_parameters={"question": question},
|
||||
step_data={
|
||||
"display_title": "Expert Query",
|
||||
"question": question,
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Show only question in panel
|
||||
console.print(
|
||||
Panel(Markdown(display_query), title="🤔 Expert Query", border_style="yellow")
|
||||
|
|
@ -263,6 +300,23 @@ def ask_expert(question: str) -> str:
|
|||
logger.error(f"Exception during content processing: {str(e)}")
|
||||
raise
|
||||
|
||||
# Record expert response in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ask_expert",
|
||||
tool_parameters={"question": question},
|
||||
step_data={
|
||||
"display_title": "Expert Response",
|
||||
"response_length": len(content),
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to record trajectory: {e}")
|
||||
|
||||
# Format and display response
|
||||
console.print(
|
||||
Panel(Markdown(content), title="Expert Response", border_style="blue")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ from rich.panel import Panel
|
|||
from ra_aid.console import console
|
||||
from ra_aid.console.formatting import print_error
|
||||
from ra_aid.tools.memory import emit_related_files
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
|
||||
def truncate_display_str(s: str, max_length: int = 30) -> str:
|
||||
|
|
@ -54,6 +56,32 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
path = Path(filepath)
|
||||
if not path.exists():
|
||||
msg = f"File not found: {filepath}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
|
||||
|
|
@ -62,10 +90,62 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
|
||||
if count == 0:
|
||||
msg = f"String not found: {truncate_display_str(old_str)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
elif count > 1 and not replace_all:
|
||||
msg = f"String appears {count} times - must be unique (use replace_all=True to replace all occurrences)"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
|
||||
|
|
@ -93,7 +173,34 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
emit_related_files.invoke({"files": [filepath]})
|
||||
except Exception as e:
|
||||
# Don't let related files error affect main function success
|
||||
print_error(f"Note: Could not add to related files: {str(e)}")
|
||||
error_msg = f"Note: Could not add to related files: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": error_msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(error_msg)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -102,5 +209,31 @@ def file_str_replace(filepath: str, old_str: str, new_str: str, *, replace_all:
|
|||
|
||||
except Exception as e:
|
||||
msg = f"Error: {str(e)}"
|
||||
|
||||
# Record error in trajectory
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
step_data={
|
||||
"error_message": msg,
|
||||
"display_title": "Error",
|
||||
},
|
||||
record_type="error",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=msg,
|
||||
tool_name="file_str_replace",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"old_str": old_str,
|
||||
"new_str": new_str,
|
||||
"replace_all": replace_all
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# Silently handle trajectory recording failures (e.g., in test environments)
|
||||
pass
|
||||
|
||||
print_error(msg)
|
||||
return {"success": False, "message": msg}
|
||||
return {"success": False, "message": msg}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
import fnmatch
|
||||
from typing import List, Tuple
|
||||
import logging
|
||||
from typing import List, Tuple, Dict, Optional, Any
|
||||
|
||||
from fuzzywuzzy import process
|
||||
from git import Repo, exc
|
||||
|
|
@ -12,6 +13,49 @@ from ra_aid.file_listing import get_all_project_files, FileListerError
|
|||
|
||||
console = Console()
|
||||
|
||||
|
||||
def record_trajectory(
|
||||
tool_name: str,
|
||||
tool_parameters: Dict,
|
||||
step_data: Dict,
|
||||
record_type: str = "tool_execution",
|
||||
is_error: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_parameters: Parameters passed to the tool
|
||||
step_data: UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message
|
||||
error_type: The type/class of the error
|
||||
"""
|
||||
try:
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name=tool_name,
|
||||
tool_parameters=tool_parameters,
|
||||
step_data=step_data,
|
||||
record_type=record_type,
|
||||
human_input_id=human_input_id,
|
||||
is_error=is_error,
|
||||
error_message=error_message,
|
||||
error_type=error_type
|
||||
)
|
||||
except (ImportError, RuntimeError):
|
||||
# If either the repository modules can't be imported or no repository is available,
|
||||
# just log and continue without recording trajectory
|
||||
logging.debug("Skipping trajectory recording: repositories not available")
|
||||
|
||||
DEFAULT_EXCLUDE_PATTERNS = [
|
||||
"*.pyc",
|
||||
"__pycache__/*",
|
||||
|
|
@ -57,7 +101,32 @@ def fuzzy_find_project_files(
|
|||
"""
|
||||
# Validate threshold
|
||||
if not 0 <= threshold <= 100:
|
||||
raise ValueError("Threshold must be between 0 and 100")
|
||||
error_msg = "Threshold must be between 0 and 100"
|
||||
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Invalid Threshold Value",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type="ValueError"
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Handle empty search term as special case
|
||||
if not search_term:
|
||||
|
|
@ -126,6 +195,27 @@ def fuzzy_find_project_files(
|
|||
else:
|
||||
info_sections.append("## Results\n*No matches found*")
|
||||
|
||||
# Record fuzzy find in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Fuzzy Find Results",
|
||||
"total_files": len(all_files),
|
||||
"matches_found": len(filtered_matches)
|
||||
},
|
||||
record_type="tool_execution"
|
||||
)
|
||||
|
||||
# Display the panel
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -138,5 +228,30 @@ def fuzzy_find_project_files(
|
|||
return filtered_matches
|
||||
|
||||
except FileListerError as e:
|
||||
console.print(f"[bold red]Error listing files: {e}[/bold red]")
|
||||
error_msg = f"Error listing files: {e}"
|
||||
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="fuzzy_find_project_files",
|
||||
tool_parameters={
|
||||
"search_term": search_term,
|
||||
"repo_path": repo_path,
|
||||
"threshold": threshold,
|
||||
"max_results": max_results,
|
||||
"include_paths": include_paths,
|
||||
"exclude_patterns": exclude_patterns,
|
||||
"include_hidden": include_hidden
|
||||
},
|
||||
step_data={
|
||||
"search_term": search_term,
|
||||
"display_title": "Fuzzy Find Error",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
console.print(f"[bold red]{error_msg}[/bold red]")
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from ra_aid.database.repositories.key_fact_repository import get_key_fact_reposi
|
|||
from ra_aid.database.repositories.key_snippet_repository import get_key_snippet_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.research_note_repository import get_research_note_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.work_log_repository import get_work_log_repository
|
||||
from ra_aid.model_formatters import key_snippets_formatter
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
|
@ -69,6 +70,22 @@ def emit_research_notes(notes: str) -> str:
|
|||
from ra_aid.model_formatters.research_notes_formatter import format_research_note
|
||||
formatted_note = format_research_note(note_id, notes)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_research_notes",
|
||||
tool_parameters={"notes": notes},
|
||||
step_data={
|
||||
"note_id": note_id,
|
||||
"display_title": "Research Notes",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display formatted note
|
||||
console.print(Panel(Markdown(formatted_note), title="🔍 Research Notes"))
|
||||
|
||||
|
|
@ -123,6 +140,23 @@ def emit_key_facts(facts: List[str]) -> str:
|
|||
console.print(f"Error storing fact: {str(e)}", style="red")
|
||||
continue
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_key_facts",
|
||||
tool_parameters={"facts": [fact]},
|
||||
step_data={
|
||||
"fact_id": fact_id,
|
||||
"fact": fact,
|
||||
"display_title": f"Key Fact #{fact_id}",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display panel with ID
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -214,6 +248,32 @@ def emit_key_snippet(snippet_info: SnippetInfo) -> str:
|
|||
if snippet_info["description"]:
|
||||
display_text.extend(["", "**Description**:", snippet_info["description"]])
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
try:
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_key_snippet",
|
||||
tool_parameters={
|
||||
"snippet_info": {
|
||||
"filepath": snippet_info["filepath"],
|
||||
"line_number": snippet_info["line_number"],
|
||||
"description": snippet_info["description"],
|
||||
# Omit the full snippet content to avoid duplicating large text in the database
|
||||
"snippet_length": len(snippet_info["snippet"])
|
||||
}
|
||||
},
|
||||
step_data={
|
||||
"snippet_id": snippet_id,
|
||||
"filepath": snippet_info["filepath"],
|
||||
"line_number": snippet_info["line_number"],
|
||||
"display_title": f"Key Snippet #{snippet_id}",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
# Display panel
|
||||
console.print(
|
||||
Panel(
|
||||
|
|
@ -248,6 +308,25 @@ def one_shot_completed(message: str) -> str:
|
|||
message: Completion message to display
|
||||
"""
|
||||
mark_task_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="one_shot_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Task Completed",
|
||||
},
|
||||
record_type="task_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
|
@ -261,6 +340,25 @@ def task_completed(message: str) -> str:
|
|||
message: Message explaining how/why the task is complete
|
||||
"""
|
||||
mark_task_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="task_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Task Completed",
|
||||
},
|
||||
record_type="task_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Task Completed"))
|
||||
log_work_event(f"Task completed:\n\n{message}")
|
||||
return "Completion noted."
|
||||
|
|
@ -275,6 +373,25 @@ def plan_implementation_completed(message: str) -> str:
|
|||
"""
|
||||
mark_should_exit(propagation_depth=1)
|
||||
mark_plan_completed(message)
|
||||
|
||||
# Record to trajectory before displaying panel
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="plan_implementation_completed",
|
||||
tool_parameters={"message": message},
|
||||
step_data={
|
||||
"completion_message": message,
|
||||
"display_title": "Plan Executed",
|
||||
},
|
||||
record_type="plan_completion",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel(Markdown(message), title="✅ Plan Executed"))
|
||||
log_work_event(f"Completed implementation:\n\n{message}")
|
||||
return "Plan completion noted."
|
||||
|
|
@ -361,10 +478,29 @@ def emit_related_files(files: List[str]) -> str:
|
|||
|
||||
results.append(f"File ID #{file_id}: {file}")
|
||||
|
||||
# Rich output - single consolidated panel for added files
|
||||
# Record to trajectory before displaying panel for added files
|
||||
if added_files:
|
||||
files_added_md = "\n".join(f"- `{file}`" for id, file in added_files)
|
||||
md_content = f"**Files Noted:**\n{files_added_md}"
|
||||
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_related_files",
|
||||
tool_parameters={"files": files},
|
||||
step_data={
|
||||
"added_files": [file for _, file in added_files],
|
||||
"added_file_ids": [file_id for file_id, _ in added_files],
|
||||
"display_title": "Related Files Noted",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(md_content),
|
||||
|
|
@ -373,10 +509,28 @@ def emit_related_files(files: List[str]) -> str:
|
|||
)
|
||||
)
|
||||
|
||||
# Display skipped binary files
|
||||
# Record to trajectory before displaying panel for binary files
|
||||
if binary_files:
|
||||
binary_files_md = "\n".join(f"- `{file}`" for file in binary_files)
|
||||
md_content = f"**Binary Files Skipped:**\n{binary_files_md}"
|
||||
|
||||
human_input_id = None
|
||||
try:
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
trajectory_repo.create(
|
||||
tool_name="emit_related_files",
|
||||
tool_parameters={"files": files},
|
||||
step_data={
|
||||
"binary_files": binary_files,
|
||||
"display_title": "Binary Files Not Added",
|
||||
},
|
||||
record_type="memory_operation",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to record trajectory: {str(e)}")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(md_content),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import os.path
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from rich.console import Console
|
||||
|
|
@ -16,6 +16,49 @@ console = Console()
|
|||
CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
def record_trajectory(
|
||||
tool_name: str,
|
||||
tool_parameters: Dict,
|
||||
step_data: Dict,
|
||||
record_type: str = "tool_execution",
|
||||
is_error: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Helper function to record trajectory information, handling the case when repositories are not available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_parameters: Parameters passed to the tool
|
||||
step_data: UI rendering data
|
||||
record_type: Type of trajectory record
|
||||
is_error: Flag indicating if this record represents an error
|
||||
error_message: The error message
|
||||
error_type: The type/class of the error
|
||||
"""
|
||||
try:
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name=tool_name,
|
||||
tool_parameters=tool_parameters,
|
||||
step_data=step_data,
|
||||
record_type=record_type,
|
||||
human_input_id=human_input_id,
|
||||
is_error=is_error,
|
||||
error_message=error_message,
|
||||
error_type=error_type
|
||||
)
|
||||
except (ImportError, RuntimeError):
|
||||
# If either the repository modules can't be imported or no repository is available,
|
||||
# just log and continue without recording trajectory
|
||||
logging.debug("Skipping trajectory recording: repositories not available")
|
||||
|
||||
|
||||
@tool
|
||||
def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
||||
"""Read and return the contents of a text file.
|
||||
|
|
@ -29,10 +72,43 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
start_time = time.time()
|
||||
try:
|
||||
if not os.path.exists(filepath):
|
||||
# Record error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Not Found",
|
||||
"error_message": f"File not found: {filepath}"
|
||||
},
|
||||
is_error=True,
|
||||
error_message=f"File not found: {filepath}",
|
||||
error_type="FileNotFoundError"
|
||||
)
|
||||
raise FileNotFoundError(f"File not found: {filepath}")
|
||||
|
||||
# Check if the file is binary
|
||||
if is_binary_file(filepath):
|
||||
# Record binary file error in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "Binary File Detected",
|
||||
"error_message": f"Cannot read binary file: {filepath}"
|
||||
},
|
||||
is_error=True,
|
||||
error_message="Cannot read binary file",
|
||||
error_type="BinaryFileError"
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cannot read binary file: {filepath}",
|
||||
|
|
@ -67,6 +143,22 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s")
|
||||
logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines")
|
||||
|
||||
# Record successful file read in trajectory
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Read",
|
||||
"line_count": line_count,
|
||||
"total_bytes": total_bytes,
|
||||
"elapsed_time": elapsed
|
||||
}
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s",
|
||||
|
|
@ -80,6 +172,25 @@ def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]:
|
|||
|
||||
return {"content": truncated}
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Record exception in trajectory (if it's not already a handled FileNotFoundError)
|
||||
if not isinstance(e, FileNotFoundError):
|
||||
record_trajectory(
|
||||
tool_name="read_file_tool",
|
||||
tool_parameters={
|
||||
"filepath": filepath,
|
||||
"encoding": encoding
|
||||
},
|
||||
step_data={
|
||||
"filepath": filepath,
|
||||
"display_title": "File Read Error",
|
||||
"error_message": str(e)
|
||||
},
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ from langchain_core.tools import tool
|
|||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -10,6 +13,24 @@ def existing_project_detected() -> dict:
|
|||
"""
|
||||
When to call: Once you have confirmed that the current working directory contains project files.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="existing_project_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "existing_project",
|
||||
"display_title": "Existing Project Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("📁 Existing Project Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
@ -30,6 +51,24 @@ def monorepo_detected() -> dict:
|
|||
"""
|
||||
When to call: After identifying that multiple packages or modules exist within a single repository.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="monorepo_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "monorepo",
|
||||
"display_title": "Monorepo Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("📦 Monorepo Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
@ -53,6 +92,24 @@ def ui_detected() -> dict:
|
|||
"""
|
||||
When to call: After detecting that the project contains a user interface layer or front-end component.
|
||||
"""
|
||||
try:
|
||||
# Record detection in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ui_detected",
|
||||
tool_parameters={},
|
||||
step_data={
|
||||
"detection_type": "ui",
|
||||
"display_title": "UI Detected",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
except Exception as e:
|
||||
# Continue even if trajectory recording fails
|
||||
console.print(f"Warning: Could not record trajectory: {str(e)}")
|
||||
|
||||
console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0))
|
||||
return {
|
||||
"hint": (
|
||||
|
|
@ -64,4 +121,4 @@ def ui_detected() -> dict:
|
|||
"- Find and note established workflows for building, bundling, and deploying the UI layer, ensuring that any new changes do not conflict with the existing pipeline.\n\n"
|
||||
"Your goal is to enhance the user interface without disrupting the cohesive look, feel, and functionality already established."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,8 @@ from rich.console import Console
|
|||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.proc.interactive import run_interactive_command
|
||||
from ra_aid.text.processing import truncate_output
|
||||
|
||||
|
|
@ -158,6 +160,30 @@ def ripgrep_search(
|
|||
info_sections.append("\n".join(params))
|
||||
|
||||
# Execute command
|
||||
# Record ripgrep search in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters={
|
||||
"pattern": pattern,
|
||||
"before_context_lines": before_context_lines,
|
||||
"after_context_lines": after_context_lines,
|
||||
"file_type": file_type,
|
||||
"case_sensitive": case_sensitive,
|
||||
"include_hidden": include_hidden,
|
||||
"follow_links": follow_links,
|
||||
"exclude_dirs": exclude_dirs,
|
||||
"fixed_string": fixed_string
|
||||
},
|
||||
step_data={
|
||||
"search_pattern": pattern,
|
||||
"display_title": "Ripgrep Search",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Markdown(f"Searching for: **{pattern}**"),
|
||||
|
|
@ -179,5 +205,34 @@ def ripgrep_search(
|
|||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="ripgrep_search",
|
||||
tool_parameters={
|
||||
"pattern": pattern,
|
||||
"before_context_lines": before_context_lines,
|
||||
"after_context_lines": after_context_lines,
|
||||
"file_type": file_type,
|
||||
"case_sensitive": case_sensitive,
|
||||
"include_hidden": include_hidden,
|
||||
"follow_links": follow_links,
|
||||
"exclude_dirs": exclude_dirs,
|
||||
"fixed_string": fixed_string
|
||||
},
|
||||
step_data={
|
||||
"search_pattern": pattern,
|
||||
"display_title": "Ripgrep Search Error",
|
||||
"error_message": error_msg
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=error_msg,
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
|
||||
console.print(Panel(error_msg, title="❌ Error", border_style="red"))
|
||||
return {"output": error_msg, "return_code": 1, "success": False}
|
||||
|
|
@ -10,6 +10,8 @@ from ra_aid.proc.interactive import run_interactive_command
|
|||
from ra_aid.text.processing import truncate_output
|
||||
from ra_aid.tools.memory import log_work_event
|
||||
from ra_aid.database.repositories.config_repository import get_config_repository
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
|
@ -54,6 +56,20 @@ def run_shell_command(
|
|||
console.print(" " + get_cowboy_message())
|
||||
console.print("")
|
||||
|
||||
# Record tool execution in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="run_shell_command",
|
||||
tool_parameters={"command": command, "timeout": timeout},
|
||||
step_data={
|
||||
"command": command,
|
||||
"display_title": "Shell Command",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Show just the command in a simple panel
|
||||
console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow"))
|
||||
|
||||
|
|
@ -96,5 +112,23 @@ def run_shell_command(
|
|||
return result
|
||||
except Exception as e:
|
||||
print()
|
||||
# Record error in trajectory
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="run_shell_command",
|
||||
tool_parameters={"command": command, "timeout": timeout},
|
||||
step_data={
|
||||
"command": command,
|
||||
"error": str(e),
|
||||
"display_title": "Shell Error",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__,
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
console.print(Panel(str(e), title="❌ Error", border_style="red"))
|
||||
return {"output": str(e), "return_code": 1, "success": False}
|
||||
|
|
@ -7,6 +7,9 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
from tavily import TavilyClient
|
||||
|
||||
from ra_aid.database.repositories.trajectory_repository import get_trajectory_repository
|
||||
from ra_aid.database.repositories.human_input_repository import get_human_input_repository
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
|
|
@ -21,9 +24,44 @@ def web_search_tavily(query: str) -> Dict:
|
|||
Returns:
|
||||
Dict containing search results from Tavily
|
||||
"""
|
||||
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
||||
# Record trajectory before displaying panel
|
||||
trajectory_repo = get_trajectory_repository()
|
||||
human_input_id = get_human_input_repository().get_most_recent_id()
|
||||
trajectory_repo.create(
|
||||
tool_name="web_search_tavily",
|
||||
tool_parameters={"query": query},
|
||||
step_data={
|
||||
"query": query,
|
||||
"display_title": "Web Search",
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id
|
||||
)
|
||||
|
||||
# Display search query panel
|
||||
console.print(
|
||||
Panel(Markdown(query), title="🔍 Searching Tavily", border_style="bright_blue")
|
||||
)
|
||||
search_result = client.search(query=query)
|
||||
return search_result
|
||||
|
||||
try:
|
||||
client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
|
||||
search_result = client.search(query=query)
|
||||
return search_result
|
||||
except Exception as e:
|
||||
# Record error in trajectory
|
||||
trajectory_repo.create(
|
||||
tool_name="web_search_tavily",
|
||||
tool_parameters={"query": query},
|
||||
step_data={
|
||||
"query": query,
|
||||
"display_title": "Web Search Error",
|
||||
"error": str(e)
|
||||
},
|
||||
record_type="tool_execution",
|
||||
human_input_id=human_input_id,
|
||||
is_error=True,
|
||||
error_message=str(e),
|
||||
error_type=type(e).__name__
|
||||
)
|
||||
# Re-raise the exception to maintain original behavior
|
||||
raise
|
||||
Loading…
Reference in New Issue