From 186904c0ca421631835c895ce9687d1927fa2f38 Mon Sep 17 00:00:00 2001 From: AI Christianson Date: Sun, 2 Mar 2025 20:06:04 -0500 Subject: [PATCH] added human input table --- ra_aid/__main__.py | 21 ++ ra_aid/database/models.py | 22 +- .../repositories/human_input_repository.py | 240 ++++++++++++++++++ ...4_20250302_200312_add_human_input_model.py | 55 ++++ ra_aid/tools/human.py | 31 ++- 5 files changed, 365 insertions(+), 4 deletions(-) create mode 100644 ra_aid/database/repositories/human_input_repository.py create mode 100644 ra_aid/migrations/004_20250302_200312_add_human_input_model.py diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index 7fbc770..7daa531 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -463,6 +463,16 @@ def main(): initial_request = ask_human.invoke( {"question": "What would you like help with?"} ) + + # Record chat input in database (redundant as ask_human already records it, + # but needed in case the ask_human implementation changes) + try: + from ra_aid.database.repositories.human_input_repository import HumanInputRepository + human_input_repo = HumanInputRepository() + human_input_repo.create(content=initial_request, source='chat') + human_input_repo.garbage_collect() + except Exception as e: + logger.error(f"Failed to record initial chat input: {str(e)}") # Get working directory and current date working_directory = os.getcwd() @@ -525,6 +535,17 @@ def main(): sys.exit(1) base_task = args.message + + # Record CLI input in database + try: + from ra_aid.database.repositories.human_input_repository import HumanInputRepository + human_input_repo = HumanInputRepository() + human_input_repo.create(content=base_task, source='cli') + # Run garbage collection to ensure we don't exceed 100 inputs + human_input_repo.garbage_collect() + logger.debug(f"Recorded CLI input: {base_task}") + except Exception as e: + logger.error(f"Failed to record CLI input: {str(e)}") config = { "configurable": {"thread_id": str(uuid.uuid4())}, 
"recursion_limit": args.recursion_limit, diff --git a/ra_aid/database/models.py b/ra_aid/database/models.py index 71ec9fb..8f2dacc 100644 --- a/ra_aid/database/models.py +++ b/ra_aid/database/models.py @@ -42,8 +42,8 @@ def initialize_database(): # to avoid circular imports # Note: This import needs to be here, not at the top level try: - from ra_aid.database.models import KeyFact, KeySnippet - db.create_tables([KeyFact, KeySnippet], safe=True) + from ra_aid.database.models import KeyFact, KeySnippet, HumanInput + db.create_tables([KeyFact, KeySnippet, HumanInput], safe=True) logger.debug("Ensured database tables exist") except Exception as e: logger.error(f"Error creating tables: {str(e)}") @@ -128,4 +128,20 @@ class KeySnippet(BaseModel): # created_at and updated_at are inherited from BaseModel class Meta: - table_name = "key_snippet" \ No newline at end of file + table_name = "key_snippet" + + +class HumanInput(BaseModel): + """ + Model representing human input stored in the database. + + Human inputs are text inputs provided by users through various interfaces + such as CLI, chat, or HIL (human-in-the-loop). This model tracks these inputs + along with their source for analysis and reference. + """ + content = peewee.TextField() + source = peewee.TextField() # 'cli', 'chat', or 'hil' + # created_at and updated_at are inherited from BaseModel + + class Meta: + table_name = "human_input" \ No newline at end of file diff --git a/ra_aid/database/repositories/human_input_repository.py b/ra_aid/database/repositories/human_input_repository.py new file mode 100644 index 0000000..871d8a0 --- /dev/null +++ b/ra_aid/database/repositories/human_input_repository.py @@ -0,0 +1,240 @@ +""" +Human input repository implementation for database access. + +This module provides a repository implementation for the HumanInput model, +following the repository pattern for data access abstraction. 
+"""
+
+from typing import Dict, List, Optional
+
+import peewee
+
+from ra_aid.database.connection import get_db
+from ra_aid.database.models import HumanInput, initialize_database
+from ra_aid.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+class HumanInputRepository:
+    """
+    Repository for managing HumanInput database operations.
+
+    This class provides methods for performing CRUD operations on the HumanInput model,
+    abstracting the database access details from the business logic.
+
+    Example:
+        repo = HumanInputRepository()
+        record = repo.create("User's message", "chat")
+        recent_inputs = repo.get_recent(5)
+    """
+
+    def __init__(self, db=None):
+        """
+        Initialize the repository with an optional database connection.
+
+        Args:
+            db: Optional database connection. If None, initialize_database() is called before each operation
+        """
+        self.db = db
+
+    def _ensure_db(self) -> None:
+        """Call initialize_database() unless a connection was injected through ``db``."""
+        if self.db is None:
+            initialize_database()
+
+    def create(self, content: str, source: str) -> HumanInput:
+        """
+        Create a new human input record in the database.
+
+        Args:
+            content: The text content of the human input
+            source: The source of the input (e.g., "cli", "chat", "hil")
+
+        Returns:
+            HumanInput: The newly created human input instance
+
+        Raises:
+            peewee.DatabaseError: If there's an error creating the record
+        """
+        try:
+            self._ensure_db()
+            input_record = HumanInput.create(content=content, source=source)
+            logger.debug(f"Created human input ID {input_record.id} from {source}")
+            return input_record
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to create human input record: {str(e)}")
+            raise
+
+    def get(self, input_id: int) -> Optional[HumanInput]:
+        """
+        Retrieve a human input record by its ID.
+
+        Args:
+            input_id: The ID of the human input to retrieve
+
+        Returns:
+            Optional[HumanInput]: The human input instance if found, None otherwise
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            self._ensure_db()
+            return HumanInput.get_or_none(HumanInput.id == input_id)
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch human input {input_id}: {str(e)}")
+            raise
+
+    def update(self, input_id: int, content: Optional[str] = None, source: Optional[str] = None) -> Optional[HumanInput]:
+        """
+        Update an existing human input record.
+
+        Args:
+            input_id: The ID of the human input to update
+            content: The new content for the human input (None leaves it unchanged)
+            source: The new source for the human input (None leaves it unchanged)
+
+        Returns:
+            Optional[HumanInput]: The updated human input if found, None otherwise
+
+        Raises:
+            peewee.DatabaseError: If there's an error updating the record
+        """
+        try:
+            # self.get() initializes the database if needed, so no separate setup call here
+            input_record = self.get(input_id)
+            if input_record is None:
+                logger.warning(f"Attempted to update non-existent human input {input_id}")
+                return None
+
+            # Only overwrite the fields the caller actually supplied
+            if content is not None:
+                input_record.content = content
+            if source is not None:
+                input_record.source = source
+
+            input_record.save()
+            logger.debug(f"Updated human input ID {input_id}")
+            return input_record
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to update human input {input_id}: {str(e)}")
+            raise
+
+    def delete(self, input_id: int) -> bool:
+        """
+        Delete a human input record by its ID.
+
+        Args:
+            input_id: The ID of the human input to delete
+
+        Returns:
+            bool: True if the record was deleted, False if it wasn't found
+
+        Raises:
+            peewee.DatabaseError: If there's an error deleting the record
+        """
+        try:
+            # self.get() initializes the database if needed, so no separate setup call here
+            input_record = self.get(input_id)
+            if input_record is None:
+                logger.warning(f"Attempted to delete non-existent human input {input_id}")
+                return False
+
+            input_record.delete_instance()
+            logger.debug(f"Deleted human input ID {input_id}")
+            return True
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to delete human input {input_id}: {str(e)}")
+            raise
+
+    def get_all(self) -> List[HumanInput]:
+        """
+        Retrieve all human input records from the database, newest first.
+
+        Returns:
+            List[HumanInput]: List of all human input instances
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            self._ensure_db()
+            return list(HumanInput.select().order_by(HumanInput.created_at.desc()))
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch all human inputs: {str(e)}")
+            raise
+
+    def get_recent(self, limit: int = 10) -> List[HumanInput]:
+        """
+        Retrieve the most recent human input records, newest first.
+
+        Args:
+            limit: Maximum number of records to retrieve (default: 10)
+
+        Returns:
+            List[HumanInput]: List of the most recent human input records
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            self._ensure_db()
+            return list(HumanInput.select().order_by(HumanInput.created_at.desc()).limit(limit))
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch recent human inputs: {str(e)}")
+            raise
+
+    def get_by_source(self, source: str) -> List[HumanInput]:
+        """
+        Retrieve human input records by source, newest first.
+
+        Args:
+            source: The source to filter by (e.g., "cli", "chat", "hil")
+
+        Returns:
+            List[HumanInput]: List of human input records from the specified source
+
+        Raises:
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        try:
+            self._ensure_db()
+            return list(HumanInput.select().where(HumanInput.source == source).order_by(HumanInput.created_at.desc()))
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to fetch human inputs by source {source}: {str(e)}")
+            raise
+
+    def garbage_collect(self, max_records: int = 100) -> int:
+        """
+        Remove old human input records when the count exceeds max_records.
+
+        This method keeps the max_records most recent records and deletes any older ones.
+
+        Args:
+            max_records: Maximum number of records to retain (default: 100)
+
+        Returns:
+            int: Number of records deleted
+
+        Raises:
+            ValueError: If max_records is less than 1
+            peewee.DatabaseError: If there's an error accessing the database
+        """
+        if max_records < 1:
+            raise ValueError("max_records must be at least 1")
+        try:
+            self._ensure_db()
+            # Nothing to delete while the table is within the retention limit
+            if HumanInput.select().count() <= max_records:
+                return 0
+
+            # Keep the newest max_records rows (by created_at) and delete everything else
+            newest = HumanInput.select(HumanInput.id).order_by(HumanInput.created_at.desc())
+            keep_ids = [record.id for record in newest.limit(max_records)]
+            deleted_count = HumanInput.delete().where(HumanInput.id.not_in(keep_ids)).execute()
+            logger.info(f"Garbage collected {deleted_count} old human input records")
+            return deleted_count
+        except peewee.DatabaseError as e:
+            logger.error(f"Failed to garbage collect human input records: {str(e)}")
+            raise
\ No newline at end of file
diff --git a/ra_aid/migrations/004_20250302_200312_add_human_input_model.py b/ra_aid/migrations/004_20250302_200312_add_human_input_model.py
new file mode 100644
index 0000000..32dc137
--- /dev/null
+++ b/ra_aid/migrations/004_20250302_200312_add_human_input_model.py
@@ -0,0 +1,55 @@
+"""Peewee migrations -- 004_20250302_200312_add_human_input_model.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):  # the postgres extension is optional; a failed import is ignored
+    import playhouse.postgres_ext as pw_pext  # not referenced anywhere in this migration
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Register the HumanInput model so the migrator creates the ``human_input`` table."""
+
+    @migrator.create_model
+    class HumanInput(pw.Model):
+        id = pw.AutoField()  # auto-incrementing primary key
+        created_at = pw.DateTimeField()
+        updated_at = pw.DateTimeField()
+        content = pw.TextField()  # text of the human input
+        source = pw.TextField()  # origin of the input, e.g. 'cli', 'chat', or 'hil'
+
+        class Meta:
+            table_name = "human_input"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Undo migrate() by removing the ``human_input`` model from the migrator."""
+
+    migrator.remove_model('human_input')
\ No newline at end of file
diff --git a/ra_aid/tools/human.py b/ra_aid/tools/human.py
index 60233a3..39f8769 100644
--- a/ra_aid/tools/human.py +++ b/ra_aid/tools/human.py @@ -5,6 +5,9 @@ from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel +from ra_aid.logging_config import get_logger + +logger = get_logger(__name__) console = Console() @@ -50,6 +53,32 @@ def ask_human(question: str) -> str: print() response = session.prompt("> ", wrap_lines=True) - print() + + # Record human response in database + try: + from ra_aid.database.repositories.human_input_repository import HumanInputRepository + from ra_aid.tools.memory import _global_memory + + # Determine the source based on context + config = _global_memory.get("config", {}) + # If chat_mode is enabled, use 'chat', otherwise determine if hil mode is active + if config.get("chat_mode", False): + source = "chat" + elif config.get("hil", False): + source = "hil" + else: + source = "chat" # Default fallback + + # Store the input + human_input_repo = HumanInputRepository() + human_input_repo.create(content=response, source=source) + + # Run garbage collection to ensure we don't exceed 100 inputs + human_input_repo.garbage_collect() + except Exception as e: + from ra_aid.logging_config import get_logger + logger = get_logger(__name__) + logger.error(f"Failed to record human input: {str(e)}") + return response