do not gc key facts/snippets associated with current human input
This commit is contained in:
parent
f88ad5bc7a
commit
9a69bb173e
|
|
@ -15,6 +15,7 @@ from rich.panel import Panel
|
|||
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_fact_repository import KeyFactRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_facts_gc_prompts import KEY_FACTS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
||||
|
|
@ -22,6 +23,7 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
|||
|
||||
console = Console()
|
||||
key_fact_repository = KeyFactRepository()
|
||||
human_input_repository = HumanInputRepository()
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -37,12 +39,27 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
deleted_facts = []
|
||||
not_found_facts = []
|
||||
failed_facts = []
|
||||
protected_facts = []
|
||||
|
||||
# Try to get the current human input to protect its facts
|
||||
current_human_input_id = None
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
current_human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||
|
||||
for fact_id in fact_ids:
|
||||
# Get the fact first to display information
|
||||
fact = key_fact_repository.get(fact_id)
|
||||
if fact:
|
||||
# Delete the fact
|
||||
# Check if this fact is associated with the current human input
|
||||
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
||||
protected_facts.append((fact_id, fact.content))
|
||||
continue
|
||||
|
||||
# Delete the fact if it's not protected
|
||||
was_deleted = key_fact_repository.delete(fact_id)
|
||||
if was_deleted:
|
||||
deleted_facts.append((fact_id, fact.content))
|
||||
|
|
@ -61,6 +78,13 @@ def delete_key_facts(fact_ids: List[int]) -> str:
|
|||
Panel(Markdown(deleted_msg), title="Facts Deleted", border_style="green")
|
||||
)
|
||||
|
||||
if protected_facts:
|
||||
protected_msg = "Protected facts (associated with current request):\n" + "\n".join([f"- #{fact_id}: {content}" for fact_id, content in protected_facts])
|
||||
result_parts.append(protected_msg)
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Facts Protected", border_style="blue")
|
||||
)
|
||||
|
||||
if not_found_facts:
|
||||
not_found_msg = f"Facts not found: {', '.join([f'#{fact_id}' for fact_id in not_found_facts])}"
|
||||
result_parts.append(not_found_msg)
|
||||
|
|
@ -77,6 +101,7 @@ def run_key_facts_gc_agent() -> None:
|
|||
|
||||
The agent analyzes all key facts and determines which are the least valuable,
|
||||
deleting them to maintain a manageable collection size of high-value facts.
|
||||
Facts associated with the current human input are excluded from deletion.
|
||||
"""
|
||||
# Get the count of key facts
|
||||
facts = key_fact_repository.get_all()
|
||||
|
|
@ -87,44 +112,75 @@ def run_key_facts_gc_agent() -> None:
|
|||
|
||||
# Only run the agent if we actually have facts to clean
|
||||
if fact_count > 0:
|
||||
# Get all facts as a formatted string for the prompt
|
||||
facts_dict = key_fact_repository.get_facts_dict()
|
||||
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
|
||||
# Try to get the current human input ID to exclude its facts
|
||||
current_human_input_id = None
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
current_human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
# Get all facts that are not associated with the current human input
|
||||
eligible_facts = []
|
||||
protected_facts = []
|
||||
for fact in facts:
|
||||
if current_human_input_id is not None and fact.human_input_id == current_human_input_id:
|
||||
protected_facts.append(fact)
|
||||
else:
|
||||
eligible_facts.append(fact)
|
||||
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
llm_config.get("provider", "anthropic"),
|
||||
llm_config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
temperature=llm_config.get("temperature")
|
||||
)
|
||||
# Only process if we have facts that can be deleted
|
||||
if eligible_facts:
|
||||
# Format facts as a dictionary for the prompt
|
||||
facts_dict = {fact.id: fact.content for fact in eligible_facts}
|
||||
formatted_facts = "\n".join([f"Fact #{k}: {v}" for k, v in facts_dict.items()])
|
||||
|
||||
# Create the agent with the delete_key_facts tool
|
||||
agent = create_agent(model, [delete_key_facts])
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
|
||||
# Format the prompt with the current facts
|
||||
prompt = KEY_FACTS_GC_PROMPT.format(key_facts=formatted_facts)
|
||||
|
||||
# Set up the agent configuration
|
||||
agent_config = {
|
||||
"recursion_limit": 50 # Set a reasonable recursion limit
|
||||
}
|
||||
|
||||
# Run the agent
|
||||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_facts = key_fact_repository.get_all()
|
||||
updated_count = len(updated_facts)
|
||||
|
||||
# Show info panel with updated count
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}",
|
||||
title="🗑 GC Complete"
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
llm_config.get("provider", "anthropic"),
|
||||
llm_config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
temperature=llm_config.get("temperature")
|
||||
)
|
||||
)
|
||||
|
||||
# Create the agent with the delete_key_facts tool
|
||||
agent = create_agent(model, [delete_key_facts])
|
||||
|
||||
# Format the prompt with the eligible facts
|
||||
prompt = KEY_FACTS_GC_PROMPT.format(key_facts=formatted_facts)
|
||||
|
||||
# Set up the agent configuration
|
||||
agent_config = {
|
||||
"recursion_limit": 50 # Set a reasonable recursion limit
|
||||
}
|
||||
|
||||
# Run the agent
|
||||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_facts = key_fact_repository.get_all()
|
||||
updated_count = len(updated_facts)
|
||||
|
||||
# Show info panel with updated count and protected facts count
|
||||
protected_count = len(protected_facts)
|
||||
if protected_count > 0:
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}\nProtected facts (associated with current request): {protected_count}",
|
||||
title="🗑 GC Complete"
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key facts: {fact_count} → {updated_count}",
|
||||
title="🗑 GC Complete"
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(Panel(f"All {len(protected_facts)} facts are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
console.print(Panel("No key facts to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -15,6 +15,7 @@ from rich.panel import Panel
|
|||
|
||||
from ra_aid.agent_utils import create_agent, run_agent_with_retry
|
||||
from ra_aid.database.repositories.key_snippet_repository import KeySnippetRepository
|
||||
from ra_aid.database.repositories.human_input_repository import HumanInputRepository
|
||||
from ra_aid.llm import initialize_llm
|
||||
from ra_aid.prompts.key_snippets_gc_prompts import KEY_SNIPPETS_GC_PROMPT
|
||||
from ra_aid.tools.memory import log_work_event, _global_memory
|
||||
|
|
@ -22,6 +23,7 @@ from ra_aid.tools.memory import log_work_event, _global_memory
|
|||
|
||||
console = Console()
|
||||
key_snippet_repository = KeySnippetRepository()
|
||||
human_input_repository = HumanInputRepository()
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -38,13 +40,29 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
results = []
|
||||
not_found_snippets = []
|
||||
failed_snippets = []
|
||||
protected_snippets = []
|
||||
|
||||
# Try to get the current human input to protect its snippets
|
||||
current_human_input_id = None
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
current_human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||
|
||||
for snippet_id in snippet_ids:
|
||||
# Get the snippet first to capture filepath for the message
|
||||
snippet = key_snippet_repository.get(snippet_id)
|
||||
if snippet:
|
||||
filepath = snippet.filepath
|
||||
# Delete from database
|
||||
|
||||
# Check if this snippet is associated with the current human input
|
||||
if current_human_input_id is not None and snippet.human_input_id == current_human_input_id:
|
||||
protected_snippets.append((snippet_id, filepath))
|
||||
continue
|
||||
|
||||
# Delete from database if not protected
|
||||
success = key_snippet_repository.delete(snippet_id)
|
||||
if success:
|
||||
success_msg = f"Successfully deleted snippet #{snippet_id} from {filepath}"
|
||||
|
|
@ -66,6 +84,13 @@ def delete_key_snippets(snippet_ids: List[int]) -> str:
|
|||
deleted_msg = "Successfully deleted snippets:\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in results])
|
||||
result_parts.append(deleted_msg)
|
||||
|
||||
if protected_snippets:
|
||||
protected_msg = "Protected snippets (associated with current request):\n" + "\n".join([f"- #{snippet_id}: {filepath}" for snippet_id, filepath in protected_snippets])
|
||||
result_parts.append(protected_msg)
|
||||
console.print(
|
||||
Panel(Markdown(protected_msg), title="Snippets Protected", border_style="blue")
|
||||
)
|
||||
|
||||
if not_found_snippets:
|
||||
not_found_msg = f"Snippets not found: {', '.join([f'#{snippet_id}' for snippet_id in not_found_snippets])}"
|
||||
result_parts.append(not_found_msg)
|
||||
|
|
@ -82,6 +107,7 @@ def run_key_snippets_gc_agent() -> None:
|
|||
|
||||
The agent analyzes all key snippets and determines which are the least valuable,
|
||||
deleting them to maintain a manageable collection size of high-value snippets.
|
||||
Snippets associated with the current human input are excluded from deletion.
|
||||
"""
|
||||
# Get the count of key snippets
|
||||
snippets = key_snippet_repository.get_all()
|
||||
|
|
@ -92,47 +118,87 @@ def run_key_snippets_gc_agent() -> None:
|
|||
|
||||
# Only run the agent if we actually have snippets to clean
|
||||
if snippet_count > 0:
|
||||
# Get all snippets as a formatted string for the prompt
|
||||
snippets_dict = key_snippet_repository.get_snippets_dict()
|
||||
formatted_snippets = "\n".join([
|
||||
f"Snippet #{k}: filepath={v['filepath']}, line_number={v['line_number']}, description={v['description']}\n```python\n{v['snippet']}\n```"
|
||||
for k, v in snippets_dict.items()
|
||||
])
|
||||
# Try to get the current human input ID to exclude its snippets
|
||||
current_human_input_id = None
|
||||
try:
|
||||
recent_inputs = human_input_repository.get_recent(1)
|
||||
if recent_inputs and len(recent_inputs) > 0:
|
||||
current_human_input_id = recent_inputs[0].id
|
||||
except Exception as e:
|
||||
console.print(f"Warning: Could not retrieve current human input: {str(e)}")
|
||||
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
# Get all snippets that are not associated with the current human input
|
||||
eligible_snippets = []
|
||||
protected_snippets = []
|
||||
for snippet in snippets:
|
||||
if current_human_input_id is not None and snippet.human_input_id == current_human_input_id:
|
||||
protected_snippets.append(snippet)
|
||||
else:
|
||||
eligible_snippets.append(snippet)
|
||||
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
llm_config.get("provider", "anthropic"),
|
||||
llm_config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
temperature=llm_config.get("temperature")
|
||||
)
|
||||
# Only process if we have snippets that can be deleted
|
||||
if eligible_snippets:
|
||||
# Get eligible snippets as a formatted string for the prompt
|
||||
snippets_dict = {
|
||||
snippet.id: {
|
||||
'filepath': snippet.filepath,
|
||||
'line_number': snippet.line_number,
|
||||
'snippet': snippet.snippet,
|
||||
'description': snippet.description
|
||||
}
|
||||
for snippet in eligible_snippets
|
||||
}
|
||||
|
||||
# Create the agent with the delete_key_snippets tool
|
||||
agent = create_agent(model, [delete_key_snippets])
|
||||
formatted_snippets = "\n".join([
|
||||
f"Snippet #{k}: filepath={v['filepath']}, line_number={v['line_number']}, description={v['description']}\n```python\n{v['snippet']}\n```"
|
||||
for k, v in snippets_dict.items()
|
||||
])
|
||||
|
||||
# Format the prompt with the current snippets
|
||||
prompt = KEY_SNIPPETS_GC_PROMPT.format(key_snippets=formatted_snippets)
|
||||
# Retrieve configuration
|
||||
llm_config = _global_memory.get("config", {})
|
||||
|
||||
# Set up the agent configuration
|
||||
agent_config = {
|
||||
"recursion_limit": 50 # Set a reasonable recursion limit
|
||||
}
|
||||
|
||||
# Run the agent
|
||||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_snippets = key_snippet_repository.get_all()
|
||||
updated_count = len(updated_snippets)
|
||||
|
||||
# Show info panel with updated count
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
||||
title="🗑 GC Complete"
|
||||
# Initialize the LLM model
|
||||
model = initialize_llm(
|
||||
llm_config.get("provider", "anthropic"),
|
||||
llm_config.get("model", "claude-3-7-sonnet-20250219"),
|
||||
temperature=llm_config.get("temperature")
|
||||
)
|
||||
)
|
||||
|
||||
# Create the agent with the delete_key_snippets tool
|
||||
agent = create_agent(model, [delete_key_snippets])
|
||||
|
||||
# Format the prompt with the eligible snippets
|
||||
prompt = KEY_SNIPPETS_GC_PROMPT.format(key_snippets=formatted_snippets)
|
||||
|
||||
# Set up the agent configuration
|
||||
agent_config = {
|
||||
"recursion_limit": 50 # Set a reasonable recursion limit
|
||||
}
|
||||
|
||||
# Run the agent
|
||||
run_agent_with_retry(agent, prompt, agent_config)
|
||||
|
||||
# Get updated count
|
||||
updated_snippets = key_snippet_repository.get_all()
|
||||
updated_count = len(updated_snippets)
|
||||
|
||||
# Show info panel with updated count and protected snippets count
|
||||
protected_count = len(protected_snippets)
|
||||
if protected_count > 0:
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}\nProtected snippets (associated with current request): {protected_count}",
|
||||
title="🗑 GC Complete"
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
Panel(
|
||||
f"Cleaned key snippets: {snippet_count} → {updated_count}",
|
||||
title="🗑 GC Complete"
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(Panel(f"All {len(protected_snippets)} snippets are associated with the current request and protected from deletion.", title="🗑 GC Info"))
|
||||
else:
|
||||
console.print(Panel("No key snippets to clean.", title="🗑 GC Info"))
|
||||
|
|
@ -37,31 +37,44 @@ with suppress(ImportError):
|
|||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Add human_input_id foreign key field to KeyFact and KeySnippet tables."""
|
||||
|
||||
# Add human_input_id field to KeyFact model
|
||||
migrator.add_fields(
|
||||
'key_fact',
|
||||
human_input_id=pw.ForeignKeyField(
|
||||
'human_input',
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='SET NULL'
|
||||
)
|
||||
)
|
||||
# Get the HumanInput model from migrator.orm
|
||||
HumanInput = migrator.orm['human_input']
|
||||
|
||||
# Add human_input_id field to KeySnippet model
|
||||
migrator.add_fields(
|
||||
'key_snippet',
|
||||
human_input_id=pw.ForeignKeyField(
|
||||
'human_input',
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='SET NULL'
|
||||
# Skip adding fields if they already exist
|
||||
# Check if the column exists before trying to add it
|
||||
try:
|
||||
# Check if key_fact table has human_input_id column
|
||||
database.execute_sql("SELECT human_input_id FROM key_fact LIMIT 1")
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
migrator.add_fields(
|
||||
'key_fact',
|
||||
human_input=pw.ForeignKeyField(
|
||||
HumanInput,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='SET NULL'
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if key_snippet table has human_input_id column
|
||||
database.execute_sql("SELECT human_input_id FROM key_snippet LIMIT 1")
|
||||
except pw.OperationalError:
|
||||
# Column doesn't exist, safe to add
|
||||
migrator.add_fields(
|
||||
'key_snippet',
|
||||
human_input=pw.ForeignKeyField(
|
||||
HumanInput,
|
||||
null=True,
|
||||
field='id',
|
||||
on_delete='SET NULL'
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Remove human_input_id field from KeyFact and KeySnippet tables."""
|
||||
|
||||
migrator.remove_fields('key_fact', 'human_input_id')
|
||||
migrator.remove_fields('key_snippet', 'human_input_id')
|
||||
migrator.remove_fields('key_fact', 'human_input')
|
||||
migrator.remove_fields('key_snippet', 'human_input')
|
||||
Loading…
Reference in New Issue