From b00fd475734a5246391f8198562df42b3d761c4c Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 29 Jan 2025 13:38:58 -0800 Subject: [PATCH] feat(Makefile): add commands for code checking and fixing using ruff (#63) refactor(ra_aid/__init__.py): reorganize imports for better readability refactor(ra_aid/__main__.py): clean up imports and improve structure refactor(ra_aid/agent_utils.py): streamline imports and improve organization refactor(ra_aid/agents/ciayn_agent.py): enhance code structure and readability refactor(ra_aid/chat_models/deepseek_chat.py): tidy up imports for clarity refactor(ra_aid/config.py): maintain consistent formatting and organization refactor(ra_aid/console/__init__.py): improve import structure for clarity refactor(ra_aid/console/cowboy_messages.py): enhance code readability refactor(ra_aid/console/formatting.py): clean up formatting functions for consistency refactor(ra_aid/console/output.py): improve output handling for better clarity refactor(ra_aid/dependencies.py): enhance dependency checking structure refactor(ra_aid/env.py): streamline environment validation logic refactor(ra_aid/exceptions.py): improve exception handling structure refactor(ra_aid/file_listing.py): enhance file listing functionality refactor(ra_aid/llm.py): improve language model initialization logic refactor(ra_aid/logging_config.py): tidy up logging configuration refactor(ra_aid/models_tokens.py): maintain consistent formatting refactor(ra_aid/proc/interactive.py): enhance subprocess handling refactor(ra_aid/project_info.py): improve project information handling refactor(ra_aid/project_state.py): streamline project state management refactor(ra_aid/provider_strategy.py): enhance provider validation logic refactor(ra_aid/tests/test_env.py): improve test structure and readability refactor(ra_aid/text/__init__.py): maintain consistent import structure refactor(ra_aid/text/processing.py): enhance text processing functions refactor(ra_aid/tool_configs.py): improve tool configuration structure refactor(ra_aid/tools/__init__.py): tidy up tool imports for clarity refactor(ra_aid/tools/agent.py): enhance agent tool functionality refactor(ra_aid/tools/expert.py): improve expert tool handling refactor(ra_aid/tools/file_str_replace.py): streamline file string replacement logic refactor(ra_aid/tools/fuzzy_find.py): enhance fuzzy find functionality refactor(ra_aid/tools/handle_user_defined_test_cmd_execution.py): improve test command execution logic refactor(ra_aid/tools/human.py): enhance human interaction handling refactor(ra_aid/tools/list_directory.py): improve directory listing functionality refactor(ra_aid/tools/memory.py): streamline memory management functions refactor(ra_aid/tools/programmer.py): enhance programming task handling refactor(ra_aid/tools/read_file.py): improve file reading functionality refactor(ra_aid/tools/reflection.py): maintain consistent function structure refactor(ra_aid/tools/research.py): enhance research tool handling refactor(ra_aid/tools/ripgrep.py): improve ripgrep search functionality refactor(ra_aid/tools/ripgrep.py): enhance ripgrep search command handling refactor(ra_aid/tools/ripgrep.py): improve search result handling refactor(ra_aid/tools/ripgrep.py): streamline search parameters handling refactor(ra_aid/tools/ripgrep.py): enhance error handling for search operations refactor(ra_aid/tools/ripgrep.py): improve output formatting for search results refactor(ra_aid/tools/ripgrep.py): maintain consistent command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search result display refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor(ra_aid/tools/ripgrep.py): enhance search command execution logic refactor(ra_aid/tools/ripgrep.py): improve search command output formatting refactor(ra_aid/tools/ripgrep.py): streamline search command construction refactor(ra_aid/tools/ripgrep.py): enhance search command error handling refactor(ra_aid/tools/ripgrep.py): improve search command output handling refactor(ra_aid/tools/ripgrep.py): maintain consistent search command structure refactor( style(server.ts): update variable naming from lowercase 'port' to uppercase 'PORT' for consistency and clarity feat(server.ts): allow server to listen on a configurable port using process.env.PORT or default to PORT style(shell.py): reorder import statements for better organization and readability style(shell.py): remove unnecessary blank lines for cleaner code style(web_search_tavily.py): reorder import statements for better organization style(write_file.py): format code for consistency and readability style(extract_changelog.py): format code for consistency and readability style(generate_swebench_dataset.py): format code for consistency and readability style(test_ciayn_agent.py): format code for consistency and readability style(test_cowboy_messages.py): format code for consistency and readability style(test_interactive.py): format code for consistency and readability style(test_agent_utils.py): format code for consistency and readability style(test_default_provider.py): format code for consistency and readability style(test_env.py): format code for consistency and readability style(test_llm.py): format code for consistency and readability style(test_main.py): format code for consistency and readability style(test_programmer.py): format code for consistency and readability style(test_provider_integration.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_utils.py): format code for consistency and readability style(test_execution.py): format code for consistency and readability style(test_expert.py): format code for consistency and readability style(test_file_str_replace.py): format code for consistency and readability style(test_fuzzy_find.py): format code for consistency and readability style(test_handle_user_defined_test_cmd_execution.py): format code for consistency and readability style(test_list_directory.py): format code for consistency and readability style(test_user_defined_test_cmd_execution.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and readability style(test_tool_configs.py): format code for consistency and style(tests): standardize string quotes to double quotes for consistency across test files style(tests): format code for better readability by adjusting spacing and line breaks style(tests): remove unnecessary blank lines to improve code cleanliness style(tests): ensure consistent use of whitespace around operators and after commas style(tests): align comments and docstrings for better visual structure and clarity --- Makefile | 11 + ra_aid/__init__.py | 25 +- ra_aid/__main__.py | 40 +- ra_aid/agent_utils.py | 111 ++--- ra_aid/agents/ciayn_agent.py | 117 +++-- ra_aid/chat_models/deepseek_chat.py | 3 +- ra_aid/config.py | 2 +- ra_aid/console/__init__.py | 17 +- ra_aid/console/cowboy_messages.py | 5 +- ra_aid/console/formatting.py | 41 +- ra_aid/console/output.py | 37 +- ra_aid/dependencies.py | 11 +- ra_aid/env.py | 80 +-- ra_aid/exceptions.py | 7 +- ra_aid/file_listing.py | 24 +- ra_aid/llm.py | 11 +- ra_aid/logging_config.py | 2 + ra_aid/models_tokens.py | 1 - ra_aid/proc/interactive.py | 26 +- ra_aid/project_info.py | 71 +-- ra_aid/project_state.py | 9 +- ra_aid/provider_strategy.py | 241 +++++---- ra_aid/tests/test_env.py | 23 +- ra_aid/text/__init__.py | 2 +- ra_aid/text/processing.py | 21 +- ra_aid/tool_configs.py | 85 +++- ra_aid/tools/__init__.py | 101 ++-- ra_aid/tools/agent.py | 231 +++++---- ra_aid/tools/expert.py | 154 +++--- ra_aid/tools/file_str_replace.py | 53 +- ra_aid/tools/fuzzy_find.py | 88 ++-- .../handle_user_defined_test_cmd_execution.py | 112 +++-- ra_aid/tools/human.py | 34 +- ra_aid/tools/list_directory.py | 89 ++-- ra_aid/tools/memory.py | 9 +- ra_aid/tools/programmer.py | 81 +-- ra_aid/tools/read_file.py | 42 +- ra_aid/tools/reflection.py | 4 +- ra_aid/tools/research.py | 9 +- ra_aid/tools/ripgrep.py | 77 +-- ra_aid/tools/shell.py | 35 +- ra_aid/tools/web_search_tavily.py | 14 +- ra_aid/tools/write_file.py | 49 +- scripts/extract_changelog.py | 6 +- scripts/generate_swebench_dataset.py | 60 +-- tests/ra_aid/agents/test_ciayn_agent.py | 150 +++--- tests/ra_aid/console/test_cowboy_messages.py | 8 +- tests/ra_aid/proc/test_interactive.py | 67 ++- tests/ra_aid/test_agent_utils.py | 20 +- tests/ra_aid/test_default_provider.py | 45 +- tests/ra_aid/test_env.py | 169 ++++--- tests/ra_aid/test_llm.py | 387 ++++++++------- tests/ra_aid/test_main.py | 141 ++++-- tests/ra_aid/test_programmer.py | 53 +- tests/ra_aid/test_provider_integration.py | 12 +- tests/ra_aid/test_tool_configs.py | 40 +- tests/ra_aid/test_utils.py | 29 +- tests/ra_aid/tools/test_execution.py | 130 ++--- tests/ra_aid/tools/test_expert.py | 69 ++- tests/ra_aid/tools/test_file_str_replace.py | 150 +++--- tests/ra_aid/tools/test_fuzzy_find.py | 115 +++-- ..._handle_user_defined_test_cmd_execution.py | 113 +++-- tests/ra_aid/tools/test_list_directory.py | 87 ++-- tests/ra_aid/tools/test_memory.py | 460 ++++++++++-------- tests/ra_aid/tools/test_read_file.py | 26 +- tests/ra_aid/tools/test_reflection.py | 12 +- tests/ra_aid/tools/test_shell.py | 96 ++-- tests/ra_aid/tools/test_write_file.py | 155 +++--- tests/scripts/test_extract_changelog.py | 9 + tests/test_file_listing.py | 42 +- tests/test_project_info.py | 36 +- tests/test_project_state.py | 12 +- 72 files changed, 2743 insertions(+), 2161 deletions(-) diff --git a/Makefile b/Makefile index 35b9254..5b2a827 100644 --- a/Makefile +++ b/Makefile @@ -9,3 +9,14 @@ setup-dev: setup-hooks: setup-dev pre-commit install + +check: + ruff check + +fix: + ruff check . --select I --fix # First sort imports + ruff format . + ruff check --fix + +fix-basic: + ruff check --fix diff --git a/ra_aid/__init__.py b/ra_aid/__init__.py index d9f3fd2..22aa72e 100644 --- a/ra_aid/__init__.py +++ b/ra_aid/__init__.py @@ -1,16 +1,21 @@ from .__version__ import __version__ -from .console.formatting import print_stage_header, print_task_header, print_error, print_interrupt +from .agent_utils import run_agent_with_retry +from .console.formatting import ( + print_error, + print_interrupt, + print_stage_header, + print_task_header, +) from .console.output import print_agent_output from .text.processing import truncate_output -from .agent_utils import run_agent_with_retry __all__ = [ - 'print_stage_header', - 'print_task_header', - 'print_agent_output', - 'truncate_output', - 'print_error', - 'print_interrupt', - 'run_agent_with_retry', - '__version__' + "print_stage_header", + "print_task_header", + "print_agent_output", + "truncate_output", + "print_error", + "print_interrupt", + "run_agent_with_retry", + "__version__", ] diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py index a5bcee7..e9a36c6 100644 --- a/ra_aid/__main__.py +++ b/ra_aid/__main__.py @@ -1,30 +1,32 @@ import argparse +import os import sys import uuid from datetime import datetime -from rich.panel import Panel -from rich.console import Console + from langgraph.checkpoint.memory import MemorySaver -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT -from ra_aid.env import validate_environment -from ra_aid.project_info import get_project_info, format_project_info -from ra_aid.tools.memory import _global_memory -from ra_aid.tools.human import ask_human -from ra_aid import print_stage_header, print_error +from rich.console import Console +from rich.panel import Panel + +from ra_aid import print_error, print_stage_header from ra_aid.__version__ import __version__ from ra_aid.agent_utils import ( AgentInterrupt, - run_agent_with_retry, - run_research_agent, - run_planning_agent, create_agent, + run_agent_with_retry, + run_planning_agent, + run_research_agent, ) -from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT -from ra_aid.llm import initialize_llm -from ra_aid.logging_config import setup_logging, get_logger -from ra_aid.tool_configs import get_chat_tools +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.dependencies import check_dependencies -import os +from ra_aid.env import validate_environment +from ra_aid.llm import initialize_llm +from ra_aid.logging_config import get_logger, setup_logging +from ra_aid.project_info import format_project_info, get_project_info +from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT +from ra_aid.tool_configs import get_chat_tools +from ra_aid.tools.human import ask_human +from ra_aid.tools.memory import _global_memory logger = get_logger(__name__) @@ -151,12 +153,12 @@ Examples: parser.add_argument( "--test-cmd", type=str, - help="Test command to run before completing tasks (e.g. 'pytest tests/')" + help="Test command to run before completing tasks (e.g. 'pytest tests/')", ) parser.add_argument( "--auto-test", action="store_true", - help="Automatically run tests before completing tasks" + help="Automatically run tests before completing tasks", ) parser.add_argument( "--max-test-cmd-retries", @@ -207,7 +209,7 @@ Examples: # Validate recursion limit is positive if parsed_args.recursion_limit <= 0: parser.error("Recursion limit must be positive") - + # if auto-test command is provided, validate test-cmd is also provided if parsed_args.auto_test and not parsed_args.test_cmd: parser.error("Test command is required when using --auto-test") diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 7cba70b..009a734 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -1,73 +1,65 @@ """Utility functions for working with agents.""" +import signal import sys +import threading import time import uuid -from typing import Optional, Any, List, Dict, Sequence -from langchain_core.messages import BaseMessage, trim_messages -from litellm import get_model_info +from typing import Any, Dict, List, Optional, Sequence + import litellm - -import signal - -from langgraph.checkpoint.memory import MemorySaver -from langgraph.prebuilt.chat_agent_executor import AgentState -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT -from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens -from ra_aid.agents.ciayn_agent import CiaynAgent -import threading - -from ra_aid.project_info import ( - get_project_info, - format_project_info, - display_project_status, -) - -from langgraph.prebuilt import create_react_agent -from ra_aid.console.formatting import print_stage_header, print_error +from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError from langchain_core.language_models import BaseChatModel +from langchain_core.messages import BaseMessage, HumanMessage, trim_messages from langchain_core.tools import tool -from ra_aid.console.output import print_agent_output -from ra_aid.logging_config import get_logger -from ra_aid.exceptions import AgentInterrupt -from ra_aid.tool_configs import ( - get_implementation_tools, - get_research_tools, - get_planning_tools, - get_web_research_tools, -) -from ra_aid.prompts import ( - IMPLEMENTATION_PROMPT, - EXPERT_PROMPT_SECTION_IMPLEMENTATION, - HUMAN_PROMPT_SECTION_IMPLEMENTATION, - EXPERT_PROMPT_SECTION_RESEARCH, - WEB_RESEARCH_PROMPT_SECTION_RESEARCH, - WEB_RESEARCH_PROMPT_SECTION_CHAT, - WEB_RESEARCH_PROMPT_SECTION_PLANNING, - RESEARCH_PROMPT, - RESEARCH_ONLY_PROMPT, - HUMAN_PROMPT_SECTION_RESEARCH, - PLANNING_PROMPT, - EXPERT_PROMPT_SECTION_PLANNING, - HUMAN_PROMPT_SECTION_PLANNING, - WEB_RESEARCH_PROMPT, -) - -from langchain_core.messages import HumanMessage -from anthropic import APIError, APITimeoutError, RateLimitError, InternalServerError -from ra_aid.tools.human import ask_human -from ra_aid.tools.shell import run_shell_command +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import create_react_agent +from langgraph.prebuilt.chat_agent_executor import AgentState +from litellm import get_model_info from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel +from ra_aid.agents.ciayn_agent import CiaynAgent +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT +from ra_aid.console.formatting import print_error, print_stage_header +from ra_aid.console.output import print_agent_output +from ra_aid.exceptions import AgentInterrupt +from ra_aid.logging_config import get_logger +from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens +from ra_aid.project_info import ( + display_project_status, + format_project_info, + get_project_info, +) +from ra_aid.prompts import ( + EXPERT_PROMPT_SECTION_IMPLEMENTATION, + EXPERT_PROMPT_SECTION_PLANNING, + EXPERT_PROMPT_SECTION_RESEARCH, + HUMAN_PROMPT_SECTION_IMPLEMENTATION, + HUMAN_PROMPT_SECTION_PLANNING, + HUMAN_PROMPT_SECTION_RESEARCH, + IMPLEMENTATION_PROMPT, + PLANNING_PROMPT, + RESEARCH_ONLY_PROMPT, + RESEARCH_PROMPT, + WEB_RESEARCH_PROMPT, + WEB_RESEARCH_PROMPT_SECTION_CHAT, + WEB_RESEARCH_PROMPT_SECTION_PLANNING, + WEB_RESEARCH_PROMPT_SECTION_RESEARCH, +) +from ra_aid.tool_configs import ( + get_implementation_tools, + get_planning_tools, + get_research_tools, + get_web_research_tools, +) +from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command from ra_aid.tools.memory import ( _global_memory, get_memory_value, get_related_files, ) -from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command - console = Console() @@ -723,7 +715,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: max_retries = 20 base_delay = 1 test_attempts = 0 - max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) + _max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES) auto_test = config.get("auto_test", False) original_prompt = prompt @@ -752,13 +744,12 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]: _global_memory["task_completed"] = False _global_memory["completion_message"] = "" break - + # Execute test command if configured - should_break, prompt, auto_test, test_attempts = execute_test_command( - config, - original_prompt, - test_attempts, - auto_test + should_break, prompt, auto_test, test_attempts = ( + execute_test_command( + config, original_prompt, test_attempts, auto_test + ) ) if should_break: break diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 388cbc5..4265d5b 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -1,21 +1,23 @@ import re from dataclasses import dataclass -from typing import Dict, Any, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage + +from ra_aid.exceptions import ToolExecutionError +from ra_aid.logging_config import get_logger from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT from ra_aid.tools.reflection import get_function_info -from langchain_core.messages import AIMessage, HumanMessage, BaseMessage, SystemMessage -from ra_aid.exceptions import ToolExecutionError -from ra_aid.logging_config import get_logger - logger = get_logger(__name__) + @dataclass class ChunkMessage: content: str status: str + def validate_function_call_pattern(s: str) -> bool: """Check if a string matches the expected function call pattern. @@ -34,6 +36,7 @@ def validate_function_call_pattern(s: str) -> bool: pattern = r"^\s*[\w_\-]+\s*\([^)(]*(?:\([^)(]*\)[^)(]*)*\)\s*$" return not re.match(pattern, s, re.DOTALL) + class CiaynAgent: """Code Is All You Need (CIAYN) agent that uses generated Python code for tool interaction. @@ -65,10 +68,15 @@ class CiaynAgent: - Memory management with configurable limits """ - - def __init__(self, model, tools: list, max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT): + def __init__( + self, + model, + tools: list, + max_history_messages: int = 50, + max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT, + ): """Initialize the agent with a model and list of tools. - + Args: model: The language model to use tools: List of tools available to the agent @@ -88,12 +96,13 @@ class CiaynAgent: base_prompt = "" if last_result is not None: base_prompt += f"\n{last_result}" - + # Add available functions section functions_list = "\n\n".join(self.available_functions) - + # Build the complete prompt without f-strings for the static parts - base_prompt += """ + base_prompt += ( + """ You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration. @@ -111,7 +120,9 @@ You typically don't want to keep calling the same function over and over with th You must ONLY use ONE of the following functions (these are the ONLY functions that exist): -""" + functions_list + """ +""" + + functions_list + + """ You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier. @@ -202,15 +213,13 @@ You have often been criticized for: DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE! Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" + ) return base_prompt def _execute_tool(self, code: str) -> str: """Execute a tool call and return its result.""" - globals_dict = { - tool.func.__name__: tool.func - for tool in self.tools - } + globals_dict = {tool.func.__name__: tool.func for tool in self.tools} try: code = code.strip() @@ -229,36 +238,28 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" - return { - "agent": { - "messages": [AIMessage(content=content)] - } - } + return {"agent": {"messages": [AIMessage(content=content)]}} def _create_error_chunk(self, content: str) -> Dict[str, Any]: """Create an error chunk in the format expected by print_agent_output.""" message = ChunkMessage(content=content, status="error") - return { - "tools": { - "messages": [message] - } - } + return {"tools": {"messages": [message]}} @staticmethod def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int: """Estimate number of tokens in content using simple byte length heuristic. - + Estimates 1 token per 4 bytes of content. For messages, uses the content field. - + Args: content: String content or Message object to estimate tokens for - + Returns: int: Estimated number of tokens, 0 if content is None/empty """ if content is None: return 0 - + if isinstance(content, BaseMessage): text = content.content else: @@ -267,59 +268,72 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" # create-react-agent tool calls can be lists if isinstance(text, List): return 0 - + if not text: return 0 - - return len(text.encode('utf-8')) // 4 - def _trim_chat_history(self, initial_messages: List[Any], chat_history: List[Any]) -> List[Any]: + return len(text.encode("utf-8")) // 4 + + def _trim_chat_history( + self, initial_messages: List[Any], chat_history: List[Any] + ) -> List[Any]: """Trim chat history based on message count and token limits while preserving initial messages. - + Applies both message count and token limits (if configured) to chat_history, while preserving all initial_messages. Returns concatenated result. - + Args: initial_messages: List of initial messages to preserve chat_history: List of chat messages that may be trimmed - + Returns: List[Any]: Concatenated initial_messages + trimmed chat_history """ # First apply message count limit if len(chat_history) > self.max_history_messages: - chat_history = chat_history[-self.max_history_messages:] - + chat_history = chat_history[-self.max_history_messages :] + # Skip token limiting if max_tokens is None if self.max_tokens is None: return initial_messages + chat_history - + # Calculate initial messages token count initial_tokens = sum(self._estimate_tokens(msg) for msg in initial_messages) - + # Remove messages from start of chat_history until under token limit while chat_history: - total_tokens = initial_tokens + sum(self._estimate_tokens(msg) for msg in chat_history) + total_tokens = initial_tokens + sum( + self._estimate_tokens(msg) for msg in chat_history + ) if total_tokens <= self.max_tokens: break chat_history.pop(0) - + return initial_messages + chat_history - def stream(self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None) -> Generator[Dict[str, Any], None, None]: + def stream( + self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None + ) -> Generator[Dict[str, Any], None, None]: """Stream agent responses in a format compatible with print_agent_output.""" initial_messages = messages_dict.get("messages", []) chat_history = [] last_result = None first_iteration = True - + while True: base_prompt = self._build_prompt(None if first_iteration else last_result) chat_history.append(HumanMessage(content=base_prompt)) - + full_history = self._trim_chat_history(initial_messages, chat_history) - response = self.model.invoke([SystemMessage("Execute efficiently yet completely as a fully autonomous agent.")] + full_history) - + response = self.model.invoke( + [ + SystemMessage( + "Execute efficiently yet completely as a fully autonomous agent." + ) + ] + + full_history + ) + try: logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response.content) @@ -328,9 +342,14 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" yield {} except ToolExecutionError as e: - chat_history.append(HumanMessage(content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again.")) + chat_history.append( + HumanMessage( + content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." + ) + ) yield self._create_error_chunk(str(e)) + def _extract_tool_call(code: str, functions_list: str) -> str: from ra_aid.tools.expert import get_model diff --git a/ra_aid/chat_models/deepseek_chat.py b/ra_aid/chat_models/deepseek_chat.py index aa22be3..0d04696 100644 --- a/ra_aid/chat_models/deepseek_chat.py +++ b/ra_aid/chat_models/deepseek_chat.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional + from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatResult from langchain_openai import ChatOpenAI -from typing import Any, List, Optional, Dict # Docs: https://api-docs.deepseek.com/guides/reasoning_model diff --git a/ra_aid/config.py b/ra_aid/config.py index 6df55f4..7977167 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -1,4 +1,4 @@ """Configuration utilities.""" DEFAULT_RECURSION_LIMIT = 100 -DEFAULT_MAX_TEST_CMD_RETRIES = 3 \ No newline at end of file +DEFAULT_MAX_TEST_CMD_RETRIES = 3 diff --git a/ra_aid/console/__init__.py b/ra_aid/console/__init__.py index 3522f88..e1d6466 100644 --- a/ra_aid/console/__init__.py +++ b/ra_aid/console/__init__.py @@ -1,4 +1,17 @@ -from .formatting import print_stage_header, print_task_header, print_error, print_interrupt, console +from .formatting import ( + console, + print_error, + print_interrupt, + print_stage_header, + print_task_header, +) from .output import print_agent_output -__all__ = ['print_stage_header', 'print_task_header', 'print_agent_output', 'console', 'print_error', 'print_interrupt'] +__all__ = [ + "print_stage_header", + "print_task_header", + "print_agent_output", + "console", + "print_error", + "print_interrupt", +] diff --git a/ra_aid/console/cowboy_messages.py b/ra_aid/console/cowboy_messages.py index 31e2f2b..91d4d48 100644 --- a/ra_aid/console/cowboy_messages.py +++ b/ra_aid/console/cowboy_messages.py @@ -7,12 +7,13 @@ COWBOY_MESSAGES = [ "This ain't my first rodeo! 🀠", "Lock and load, partner! 🀠", "I'm just a baby πŸ‘Ά", - "I'll try not to destroy everything 😏" + "I'll try not to destroy everything 😏", ] + def get_cowboy_message() -> str: """Randomly select and return a cowboy message. - + Returns: str: A randomly selected cowboy message """ diff --git a/ra_aid/console/formatting.py b/ra_aid/console/formatting.py index ef77904..eb3c9a9 100644 --- a/ra_aid/console/formatting.py +++ b/ra_aid/console/formatting.py @@ -1,60 +1,65 @@ from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel console = Console() + def print_stage_header(stage: str) -> None: """Print a stage header with stage-specific styling and icons. - + Args: stage: The stage name to print (automatically formatted to Title Case) """ # Define stage icons mapping - using single-width emojis to prevent line wrapping issues icons = { - 'research stage': 'πŸ”Ž', - 'planning stage': 'πŸ“', - 'implementation stage': 'πŸ”§', # Changed from πŸ› οΈ to prevent wrapping - 'task completed': 'βœ…', - 'debug stage': 'πŸ›', - 'testing stage': 'πŸ§ͺ', - 'research subtasks': 'πŸ“š', - 'skipping implementation stage': '⏭️' + "research stage": "πŸ”Ž", + "planning stage": "πŸ“", + "implementation stage": "πŸ”§", # Changed from πŸ› οΈ to prevent wrapping + "task completed": "βœ…", + "debug stage": "πŸ›", + "testing stage": "πŸ§ͺ", + "research subtasks": "πŸ“š", + "skipping implementation stage": "⏭️", } # Format stage name to Title Case and normalize for mapping lookup stage_title = stage.title() stage_key = stage.lower() - + # Get appropriate icon with fallback - icon = icons.get(stage_key, 'πŸš€') - + icon = icons.get(stage_key, "πŸš€") + # Create styled panel with icon panel_content = f"{icon} {stage_title}" console.print(Panel(panel_content, style="green bold", padding=0)) + def print_task_header(task: str) -> None: """Print a task header with yellow styling and wrench emoji. Content is rendered as Markdown. - + Args: task: The task text to print (supports Markdown formatting) """ console.print(Panel(Markdown(task), title="πŸ”§ Task", border_style="yellow bold")) + def print_error(message: str) -> None: """Print an error message in a red-bordered panel with warning emoji. - + Args: message: The error message to display (supports Markdown formatting) """ console.print(Panel(Markdown(message), title="Error", border_style="red bold")) + def print_interrupt(message: str) -> None: """Print an interrupt message in a yellow-bordered panel with stop emoji. - + Args: message: The interrupt message to display (supports Markdown formatting) """ print() # Add spacing for ^C - console.print(Panel(Markdown(message), title="β›” Interrupt", border_style="yellow bold")) - + console.print( + Panel(Markdown(message), title="β›” Interrupt", border_style="yellow bold") + ) diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index 57b4593..8b64142 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,31 +1,42 @@ from typing import Any, Dict -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown + from langchain_core.messages import AIMessage +from rich.markdown import Markdown +from rich.panel import Panel # Import shared console instance from .formatting import console + def print_agent_output(chunk: Dict[str, Any]) -> None: """Print only the agent's message content, not tool calls. - + Args: chunk: A dictionary containing agent or tool messages """ - if 'agent' in chunk and 'messages' in chunk['agent']: - messages = chunk['agent']['messages'] + if "agent" in chunk and "messages" in chunk["agent"]: + messages = chunk["agent"]["messages"] for msg in messages: if isinstance(msg, AIMessage): # Handle text content if isinstance(msg.content, list): for content in msg.content: - if content['type'] == 'text' and content['text'].strip(): - console.print(Panel(Markdown(content['text']), title="πŸ€– Assistant")) + if content["type"] == "text" and content["text"].strip(): + console.print( + Panel(Markdown(content["text"]), title="πŸ€– Assistant") + ) else: if msg.content.strip(): - console.print(Panel(Markdown(msg.content.strip()), title="πŸ€– Assistant")) - elif 'tools' in chunk and 'messages' in chunk['tools']: - for msg in chunk['tools']['messages']: - if msg.status == 'error' and msg.content: - console.print(Panel(Markdown(msg.content.strip()), title="❌ Tool Error", border_style="red bold")) \ No newline at end of file + console.print( + Panel(Markdown(msg.content.strip()), title="πŸ€– Assistant") + ) + elif "tools" in chunk and "messages" in chunk["tools"]: + for msg in chunk["tools"]["messages"]: + if msg.status == "error" and msg.content: + console.print( + Panel( + Markdown(msg.content.strip()), + title="❌ Tool Error", + border_style="red bold", + ) + ) diff --git a/ra_aid/dependencies.py b/ra_aid/dependencies.py index 0e0fb86..483b673 100644 --- a/ra_aid/dependencies.py +++ b/ra_aid/dependencies.py @@ -1,21 +1,27 @@ """Module for checking system dependencies required by RA.Aid.""" + import os import sys -from ra_aid import print_error from abc import ABC, abstractmethod +from ra_aid import print_error + + class Dependency(ABC): """Base class for system dependencies.""" + @abstractmethod def check(self): """Check if the dependency is installed.""" pass + class RipGrepDependency(Dependency): """Dependency checker for ripgrep.""" + def check(self): """Check if ripgrep is installed.""" - result = os.system('rg --version > /dev/null 2>&1') + result = os.system("rg --version > /dev/null 2>&1") if result != 0: print_error("Required dependency 'ripgrep' is not installed.") print("Please install ripgrep:") @@ -25,6 +31,7 @@ class RipGrepDependency(Dependency): print(" - Other: https://github.com/BurntSushi/ripgrep#installation") sys.exit(1) + def check_dependencies(): """Check if required system dependencies are installed.""" dependencies = [RipGrepDependency()] # Create instances diff --git a/ra_aid/env.py b/ra_aid/env.py index cd8b317..56a1bba 100644 --- a/ra_aid/env.py +++ b/ra_aid/env.py @@ -2,17 +2,10 @@ import os import sys -from dataclasses import dataclass -from typing import Tuple, List, Any +from typing import Any, List -from ra_aid import print_error from ra_aid.provider_strategy import ProviderFactory, ValidationResult -@dataclass -class ValidationResult: - """Result of validation.""" - valid: bool - missing_vars: List[str] def validate_provider(provider: str) -> ValidationResult: """Validate provider configuration.""" @@ -20,9 +13,12 @@ def validate_provider(provider: str) -> ValidationResult: return ValidationResult(valid=False, missing_vars=["No provider specified"]) strategy = ProviderFactory.create(provider) if not strategy: - return ValidationResult(valid=False, missing_vars=[f"Unknown provider: {provider}"]) + return ValidationResult( + valid=False, missing_vars=[f"Unknown provider: {provider}"] + ) return strategy.validate() + def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None: """Copy base provider environment variables to expert provider if not set. @@ -32,28 +28,24 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None: """ # Map of base to expert environment variables for each provider provider_vars = { - 'openai': { - 'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY', - 'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE' + "openai": { + "OPENAI_API_KEY": "EXPERT_OPENAI_API_KEY", + "OPENAI_API_BASE": "EXPERT_OPENAI_API_BASE", }, - 'openai-compatible': { - 'OPENAI_API_KEY': 'EXPERT_OPENAI_API_KEY', - 'OPENAI_API_BASE': 'EXPERT_OPENAI_API_BASE' + "openai-compatible": { + "OPENAI_API_KEY": "EXPERT_OPENAI_API_KEY", + "OPENAI_API_BASE": "EXPERT_OPENAI_API_BASE", }, - 'anthropic': { - 'ANTHROPIC_API_KEY': 'EXPERT_ANTHROPIC_API_KEY', - 'ANTHROPIC_MODEL': 'EXPERT_ANTHROPIC_MODEL' + "anthropic": { + "ANTHROPIC_API_KEY": "EXPERT_ANTHROPIC_API_KEY", + "ANTHROPIC_MODEL": "EXPERT_ANTHROPIC_MODEL", }, - 'openrouter': { - 'OPENROUTER_API_KEY': 'EXPERT_OPENROUTER_API_KEY' + "openrouter": {"OPENROUTER_API_KEY": "EXPERT_OPENROUTER_API_KEY"}, + "gemini": { + "GEMINI_API_KEY": "EXPERT_GEMINI_API_KEY", + "GEMINI_MODEL": "EXPERT_GEMINI_MODEL", }, - 'gemini': { - 'GEMINI_API_KEY': 'EXPERT_GEMINI_API_KEY', - 'GEMINI_MODEL': 'EXPERT_GEMINI_MODEL' - }, - 'deepseek': { - 'DEEPSEEK_API_KEY': 'EXPERT_DEEPSEEK_API_KEY' - } + "deepseek": {"DEEPSEEK_API_KEY": "EXPERT_DEEPSEEK_API_KEY"}, } # Get the variables to copy based on the expert provider @@ -63,14 +55,17 @@ def copy_base_to_expert_vars(base_provider: str, expert_provider: str) -> None: if not os.environ.get(expert_var) and os.environ.get(base_var): os.environ[expert_var] = os.environ[base_var] + def validate_expert_provider(provider: str) -> ValidationResult: """Validate expert provider configuration with fallback.""" if not provider: return ValidationResult(valid=True, missing_vars=[]) - + strategy = ProviderFactory.create(provider) if not strategy: - return ValidationResult(valid=False, missing_vars=[f"Unknown expert provider: {provider}"]) + return ValidationResult( + valid=False, missing_vars=[f"Unknown expert provider: {provider}"] + ) # Copy base vars to expert vars for fallback copy_base_to_expert_vars(provider, provider) @@ -78,7 +73,7 @@ def validate_expert_provider(provider: str) -> ValidationResult: # Validate expert configuration result = strategy.validate() missing = [] - + for var in result.missing_vars: key = var.split()[0] # Get the key name without the error message expert_key = f"EXPERT_{key}" @@ -87,20 +82,25 @@ def validate_expert_provider(provider: str) -> ValidationResult: return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + def validate_web_research() -> ValidationResult: """Validate web research configuration.""" key = "TAVILY_API_KEY" return ValidationResult( valid=bool(os.environ.get(key)), - missing_vars=[] if os.environ.get(key) else [f"{key} environment variable is not set"] + missing_vars=[] + if os.environ.get(key) + else [f"{key} environment variable is not set"], ) + def print_missing_dependencies(missing_vars: List[str]) -> None: """Print missing dependencies and exit.""" for var in missing_vars: print(f"Error: {var}", file=sys.stderr) sys.exit(1) + def validate_research_only_provider(args: Any) -> None: """Validate provider and model for research-only mode. @@ -111,16 +111,17 @@ def validate_research_only_provider(args: Any) -> None: SystemExit: If provider or model validation fails """ # Get provider from args - provider = args.provider if args and hasattr(args, 'provider') else None + provider = args.provider if args and hasattr(args, "provider") else None if not provider: sys.exit("No provider specified") # For non-Anthropic providers in research-only mode, model must be specified - if provider != 'anthropic': - model = args.model if hasattr(args, 'model') and args.model else None + if provider != "anthropic": + model = args.model if hasattr(args, "model") and args.model else None if not model: sys.exit("Model is required for non-Anthropic providers") + def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]]: """Validate environment variables for research-only mode. @@ -141,14 +142,15 @@ def validate_research_only(args: Any) -> tuple[bool, list[str], bool, list[str]] web_research_missing = [] # Validate web research dependencies - tavily_key = os.environ.get('TAVILY_API_KEY') + tavily_key = os.environ.get("TAVILY_API_KEY") if not tavily_key: - web_research_missing.append('TAVILY_API_KEY environment variable is not set') + web_research_missing.append("TAVILY_API_KEY environment variable is not set") else: web_research_enabled = True return expert_enabled, expert_missing, web_research_enabled, web_research_missing + def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]: """Validate environment variables for providers and web research tools. @@ -163,9 +165,9 @@ def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]: - web_research_missing: List of missing web research dependencies """ # For research-only mode, use separate validation - if hasattr(args, 'research_only') and args.research_only: + if hasattr(args, "research_only") and args.research_only: # Only validate provider and model when testing provider validation - if hasattr(args, 'model') and args.model is None: + if hasattr(args, "model") and args.model is None: validate_research_only_provider(args) return validate_research_only(args) @@ -176,7 +178,7 @@ def validate_environment(args: Any) -> tuple[bool, list[str], bool, list[str]]: web_research_missing = [] # Get provider from args - provider = args.provider if args and hasattr(args, 'provider') else None + provider = args.provider if args and hasattr(args, "provider") else None if not provider: sys.exit("No provider specified") diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 9831a24..696b47e 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -1,18 +1,21 @@ """Custom exceptions for RA.Aid.""" + class AgentInterrupt(Exception): """Exception raised when an agent's execution is interrupted. - + This exception is used for internal agent interruption handling, separate from KeyboardInterrupt which is reserved for top-level handling. """ + pass class ToolExecutionError(Exception): """Exception raised when a tool execution fails. - + This exception is used to distinguish tool execution failures from other types of errors in the agent system. """ + pass diff --git a/ra_aid/file_listing.py b/ra_aid/file_listing.py index 1e6cf91..c83cd7b 100644 --- a/ra_aid/file_listing.py +++ b/ra_aid/file_listing.py @@ -7,21 +7,25 @@ from typing import List, Optional, Tuple class FileListerError(Exception): """Base exception for file listing related errors.""" + pass class GitCommandError(FileListerError): """Raised when a git command fails.""" + pass class DirectoryNotFoundError(FileListerError): """Raised when the specified directory does not exist.""" + pass class DirectoryAccessError(FileListerError): """Raised when the directory cannot be accessed due to permissions.""" + pass @@ -51,7 +55,7 @@ def is_git_repo(directory: str) -> bool: ["git", "rev-parse", "--git-dir"], cwd=str(path), capture_output=True, - text=True + text=True, ) return result.returncode == 0 @@ -65,7 +69,9 @@ def is_git_repo(directory: str) -> bool: raise FileListerError(f"Error checking git repository: {e}") -def get_file_listing(directory: str, limit: Optional[int] = None) -> Tuple[List[str], int]: +def get_file_listing( + directory: str, limit: Optional[int] = None +) -> Tuple[List[str], int]: """ Get a list of tracked files in a git repository. @@ -99,29 +105,25 @@ def get_file_listing(directory: str, limit: Optional[int] = None) -> Tuple[List[ cwd=directory, capture_output=True, text=True, - check=True + check=True, ) # Process the output - files = [ - line.strip() - for line in result.stdout.splitlines() - if line.strip() - ] + files = [line.strip() for line in result.stdout.splitlines() if line.strip()] # Deduplicate and sort for consistency files = list(dict.fromkeys(files)) # Remove duplicates while preserving order # Sort for consistency files.sort() - + # Get total count before truncation total_count = len(files) - + # Truncate if limit specified if limit is not None: files = files[:limit] - + return files, total_count except subprocess.CalledProcessError as e: diff --git a/ra_aid/llm.py b/ra_aid/llm.py index 5e22fa1..16d5036 100644 --- a/ra_aid/llm.py +++ b/ra_aid/llm.py @@ -1,14 +1,17 @@ import os -from typing import Optional, Dict, Any -from langchain_openai import ChatOpenAI +from typing import Any, Dict, Optional + from langchain_anthropic import ChatAnthropic from langchain_core.language_models import BaseChatModel from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_openai import ChatOpenAI + from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner from ra_aid.logging_config import get_logger logger = get_logger(__name__) + def get_env_var(name: str, expert: bool = False) -> Optional[str]: """Get environment variable with optional expert prefix and fallback.""" prefix = "EXPERT_" if expert else "" @@ -129,7 +132,7 @@ def create_llm_client( provider, model_name, temperature, - is_expert + is_expert, ) # Handle temperature settings @@ -197,4 +200,4 @@ def initialize_expert_llm( provider: str = "openai", model_name: str = "o1" ) -> BaseChatModel: """Initialize an expert language model client based on the specified provider and model.""" - return create_llm_client(provider, model_name, temperature=None, is_expert=True) \ No newline at end of file + return create_llm_client(provider, model_name, temperature=None, is_expert=True) diff --git a/ra_aid/logging_config.py b/ra_aid/logging_config.py index 8790622..a40aa3a 100644 --- a/ra_aid/logging_config.py +++ b/ra_aid/logging_config.py @@ -2,6 +2,7 @@ import logging import sys from typing import Optional + def setup_logging(verbose: bool = False) -> None: logger = logging.getLogger("ra_aid") logger.setLevel(logging.DEBUG if verbose else logging.INFO) @@ -14,5 +15,6 @@ def setup_logging(verbose: bool = False) -> None: handler.setFormatter(formatter) logger.addHandler(handler) + def get_logger(name: Optional[str] = None) -> logging.Logger: return logging.getLogger(f"ra_aid.{name}" if name else "ra_aid") diff --git a/ra_aid/models_tokens.py b/ra_aid/models_tokens.py index 01aa2ae..a15214e 100644 --- a/ra_aid/models_tokens.py +++ b/ra_aid/models_tokens.py @@ -5,7 +5,6 @@ List of model tokens DEFAULT_TOKEN_LIMIT = 100000 - models_tokens = { "openai": { "gpt-3.5-turbo-0125": 16385, diff --git a/ra_aid/proc/interactive.py b/ra_aid/proc/interactive.py index a097fa9..9d7722b 100644 --- a/ra_aid/proc/interactive.py +++ b/ra_aid/proc/interactive.py @@ -4,15 +4,15 @@ Module for running interactive subprocesses with output capture. import os import re -import tempfile import shlex import shutil +import tempfile from typing import List, Tuple - # Add macOS detection IS_MACOS = os.uname().sysname == "Darwin" + def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: """ Runs an interactive command with a pseudo-tty, capturing combined output. @@ -31,7 +31,7 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: # Fail early if cmd is empty if not cmd: raise ValueError("No command provided.") - + # Check that the command exists if shutil.which(cmd[0]) is None: raise FileNotFoundError(f"Command '{cmd[0]}' not found in PATH.") @@ -45,7 +45,7 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: retcode_file.close() # Quote arguments for safety - quoted_cmd = ' '.join(shlex.quote(c) for c in cmd) + quoted_cmd = " ".join(shlex.quote(c) for c in cmd) # Use script to capture output with TTY and save return code shell_cmd = f"{quoted_cmd}; echo $? > {shlex.quote(retcode_path)}" @@ -56,23 +56,25 @@ def run_interactive_command(cmd: List[str]) -> Tuple[bytes, int]: try: # Disable pagers by setting environment variables - os.environ['GIT_PAGER'] = '' - os.environ['PAGER'] = '' - + os.environ["GIT_PAGER"] = "" + os.environ["PAGER"] = "" + # Run command with script for TTY and output capture if IS_MACOS: os.system(f"script -q {shlex.quote(output_path)} {shell_cmd}") else: - os.system(f"script -q -c {shlex.quote(shell_cmd)} {shlex.quote(output_path)}") + os.system( + f"script -q -c {shlex.quote(shell_cmd)} {shlex.quote(output_path)}" + ) # Read and clean the output with open(output_path, "rb") as f: output = f.read() - + # Clean ANSI escape sequences and control characters - output = re.sub(rb'\x1b\[[0-9;]*[a-zA-Z]', b'', output) # ANSI escape sequences - output = re.sub(rb'[\x00-\x08\x0b\x0c\x0e-\x1f]', b'', output) # Control chars - + output = re.sub(rb"\x1b\[[0-9;]*[a-zA-Z]", b"", output) # ANSI escape sequences + output = re.sub(rb"[\x00-\x08\x0b\x0c\x0e-\x1f]", b"", output) # Control chars + # Get the return code with open(retcode_path, "r") as f: return_code = int(f.read().strip()) diff --git a/ra_aid/project_info.py b/ra_aid/project_info.py index 9b35d28..5c4371b 100644 --- a/ra_aid/project_info.py +++ b/ra_aid/project_info.py @@ -2,25 +2,33 @@ from dataclasses import dataclass from typing import List, Optional + from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel -__all__ = ['ProjectInfo', 'ProjectInfoError', 'get_project_info', 'format_project_info', 'display_project_status'] +__all__ = [ + "ProjectInfo", + "ProjectInfoError", + "get_project_info", + "format_project_info", + "display_project_status", +] -from ra_aid.project_state import is_new_project, ProjectStateError -from ra_aid.file_listing import get_file_listing, FileListerError +from ra_aid.file_listing import FileListerError, get_file_listing +from ra_aid.project_state import ProjectStateError, is_new_project @dataclass class ProjectInfo: """Data class containing project information. - + Attributes: is_new: Whether the project is new/empty files: List of tracked files in the project total_files: Total number of tracked files (before any limit) """ + is_new: bool files: List[str] total_files: int @@ -28,6 +36,7 @@ class ProjectInfo: class ProjectInfoError(Exception): """Base exception for project info related errors.""" + pass @@ -50,17 +59,13 @@ def get_project_info(directory: str, file_limit: Optional[int] = None) -> Projec try: # Check if project is new new_status = is_new_project(directory) - + # Get file listing files, total = get_file_listing(directory, limit=file_limit) - - return ProjectInfo( - is_new=new_status, - files=files, - total_files=total - ) - - except (ProjectStateError, FileListerError) as e: + + return ProjectInfo(is_new=new_status, files=files, total_files=total) + + except (ProjectStateError, FileListerError): # Re-raise known errors raise except Exception as e: @@ -70,51 +75,61 @@ def get_project_info(directory: str, file_limit: Optional[int] = None) -> Projec def format_project_info(info: ProjectInfo) -> str: """Format project information into a displayable string. - + Args: info: ProjectInfo object to format - + Returns: Formatted string containing project status and file listing """ # Create project status line status = "New/Empty Project" if info.is_new else "Existing Project" - + # Handle empty project case if info.total_files == 0: return f"Project Status: {status}\nTotal Files: 0\nFiles: None" - + # Format file count with truncation notice if needed - file_count = f"{len(info.files)} of {info.total_files}" if len(info.files) < info.total_files else str(info.total_files) + file_count = ( + f"{len(info.files)} of {info.total_files}" + if len(info.files) < info.total_files + else str(info.total_files) + ) file_count_line = f"Total Files: {file_count}" - + # Format file listing files_section = "Files:\n" + "\n".join(f"- {f}" for f in info.files) - + # Add truncation notice if list was truncated if len(info.files) < info.total_files: - files_section += f"\n[Note: Showing {len(info.files)} of {info.total_files} total files]" - + files_section += ( + f"\n[Note: Showing {len(info.files)} of {info.total_files} total files]" + ) + return f"Project Status: {status}\n{file_count_line}\n{files_section}" def display_project_status(info: ProjectInfo) -> None: """Display project status in a visual panel. - + Args: info: ProjectInfo object containing project state - """ + """ # Create project status text status = "**New/empty project**" if info.is_new else "**Existing project**" - + # Format file count (with truncation notice if needed) - file_count = f"{len(info.files)} of {info.total_files}" if len(info.files) < info.total_files else str(info.total_files) - + file_count = ( + f"{len(info.files)} of {info.total_files}" + if len(info.files) < info.total_files + else str(info.total_files) + ) + # Build status text with markdown status_text = f""" {status} with **{file_count} file(s)** """ - # Add truncation notice if list was truncated + # Add truncation notice if list was truncated if len(info.files) < info.total_files: status_text += f"\n[*Note: File listing truncated ({len(info.files)} of {info.total_files} shown)*]" diff --git a/ra_aid/project_state.py b/ra_aid/project_state.py index d49c24b..05f576a 100644 --- a/ra_aid/project_state.py +++ b/ra_aid/project_state.py @@ -6,16 +6,19 @@ from typing import Set class ProjectStateError(Exception): """Base exception for project state related errors.""" + pass class DirectoryNotFoundError(ProjectStateError): """Raised when the specified directory does not exist.""" + pass class DirectoryAccessError(ProjectStateError): """Raised when the directory cannot be accessed due to permissions.""" + pass @@ -47,18 +50,18 @@ def is_new_project(directory: str) -> bool: raise DirectoryNotFoundError(f"Path is not a directory: {directory}") # Get all files/dirs in the directory, excluding contents of .git - allowed_items: Set[str] = {'.git', '.gitignore'} + _allowed_items: Set[str] = {".git", ".gitignore"} try: contents = set() for item in path.iterdir(): # Only consider top-level items - if item.name != '.git': + if item.name != ".git": contents.add(item.name) except PermissionError as e: raise DirectoryAccessError(f"Cannot access directory {directory}: {e}") # Directory is new if empty or only contains .gitignore - return len(contents) == 0 or contents.issubset({'.gitignore'}) + return len(contents) == 0 or contents.issubset({".gitignore"}) except Exception as e: if isinstance(e, ProjectStateError): diff --git a/ra_aid/provider_strategy.py b/ra_aid/provider_strategy.py index 8ac82aa..735ef3a 100644 --- a/ra_aid/provider_strategy.py +++ b/ra_aid/provider_strategy.py @@ -1,17 +1,20 @@ """Provider validation strategies.""" -from abc import ABC, abstractmethod import os import re +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, List, Any +from typing import Any, List, Optional + @dataclass class ValidationResult: """Result of validation.""" + valid: bool missing_vars: List[str] + class ProviderStrategy(ABC): """Abstract base class for provider validation strategies.""" @@ -20,6 +23,7 @@ class ProviderStrategy(ABC): """Validate provider environment variables.""" pass + class OpenAIStrategy(ProviderStrategy): """OpenAI provider validation strategy.""" @@ -28,41 +32,50 @@ class OpenAIStrategy(ProviderStrategy): missing = [] # Check if we're validating expert config - if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai': - key = os.environ.get('EXPERT_OPENAI_API_KEY') - if not key or key == '': + if ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "openai" + ): + key = os.environ.get("EXPERT_OPENAI_API_KEY") + if not key or key == "": # Try to copy from base if not set - base_key = os.environ.get('OPENAI_API_KEY') + base_key = os.environ.get("OPENAI_API_KEY") if base_key: - os.environ['EXPERT_OPENAI_API_KEY'] = base_key + os.environ["EXPERT_OPENAI_API_KEY"] = base_key key = base_key if not key: - missing.append('EXPERT_OPENAI_API_KEY environment variable is not set') + missing.append("EXPERT_OPENAI_API_KEY environment variable is not set") # Check expert model only for research-only mode - if hasattr(args, 'research_only') and args.research_only: - model = args.expert_model if hasattr(args, 'expert_model') else None + if hasattr(args, "research_only") and args.research_only: + model = args.expert_model if hasattr(args, "expert_model") else None if not model: - model = os.environ.get('EXPERT_OPENAI_MODEL') + model = os.environ.get("EXPERT_OPENAI_MODEL") if not model: - model = os.environ.get('OPENAI_MODEL') + model = os.environ.get("OPENAI_MODEL") if not model: - missing.append('Model is required for OpenAI provider in research-only mode') + missing.append( + "Model is required for OpenAI provider in research-only mode" + ) else: - key = os.environ.get('OPENAI_API_KEY') + key = os.environ.get("OPENAI_API_KEY") if not key: - missing.append('OPENAI_API_KEY environment variable is not set') + missing.append("OPENAI_API_KEY environment variable is not set") # Check model only for research-only mode - if hasattr(args, 'research_only') and args.research_only: - model = args.model if hasattr(args, 'model') else None + if hasattr(args, "research_only") and args.research_only: + model = args.model if hasattr(args, "model") else None if not model: - model = os.environ.get('OPENAI_MODEL') + model = os.environ.get("OPENAI_MODEL") if not model: - missing.append('Model is required for OpenAI provider in research-only mode') + missing.append( + "Model is required for OpenAI provider in research-only mode" + ) return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + class OpenAICompatibleStrategy(ProviderStrategy): """OpenAI-compatible provider validation strategy.""" @@ -71,84 +84,97 @@ class OpenAICompatibleStrategy(ProviderStrategy): missing = [] # Check if we're validating expert config - if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openai-compatible': - key = os.environ.get('EXPERT_OPENAI_API_KEY') - base = os.environ.get('EXPERT_OPENAI_API_BASE') + if ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "openai-compatible" + ): + key = os.environ.get("EXPERT_OPENAI_API_KEY") + base = os.environ.get("EXPERT_OPENAI_API_BASE") # Try to copy from base if not set - if not key or key == '': - base_key = os.environ.get('OPENAI_API_KEY') + if not key or key == "": + base_key = os.environ.get("OPENAI_API_KEY") if base_key: - os.environ['EXPERT_OPENAI_API_KEY'] = base_key + os.environ["EXPERT_OPENAI_API_KEY"] = base_key key = base_key - if not base or base == '': - base_base = os.environ.get('OPENAI_API_BASE') + if not base or base == "": + base_base = os.environ.get("OPENAI_API_BASE") if base_base: - os.environ['EXPERT_OPENAI_API_BASE'] = base_base + os.environ["EXPERT_OPENAI_API_BASE"] = base_base base = base_base if not key: - missing.append('EXPERT_OPENAI_API_KEY environment variable is not set') + missing.append("EXPERT_OPENAI_API_KEY environment variable is not set") if not base: - missing.append('EXPERT_OPENAI_API_BASE environment variable is not set') + missing.append("EXPERT_OPENAI_API_BASE environment variable is not set") # Check expert model only for research-only mode - if hasattr(args, 'research_only') and args.research_only: - model = args.expert_model if hasattr(args, 'expert_model') else None + if hasattr(args, "research_only") and args.research_only: + model = args.expert_model if hasattr(args, "expert_model") else None if not model: - model = os.environ.get('EXPERT_OPENAI_MODEL') + model = os.environ.get("EXPERT_OPENAI_MODEL") if not model: - model = os.environ.get('OPENAI_MODEL') + model = os.environ.get("OPENAI_MODEL") if not model: - missing.append('Model is required for OpenAI-compatible provider in research-only mode') + missing.append( + "Model is required for OpenAI-compatible provider in research-only mode" + ) else: - key = os.environ.get('OPENAI_API_KEY') - base = os.environ.get('OPENAI_API_BASE') + key = os.environ.get("OPENAI_API_KEY") + base = os.environ.get("OPENAI_API_BASE") if not key: - missing.append('OPENAI_API_KEY environment variable is not set') + missing.append("OPENAI_API_KEY environment variable is not set") if not base: - missing.append('OPENAI_API_BASE environment variable is not set') + missing.append("OPENAI_API_BASE environment variable is not set") # Check model only for research-only mode - if hasattr(args, 'research_only') and args.research_only: - model = args.model if hasattr(args, 'model') else None + if hasattr(args, "research_only") and args.research_only: + model = args.model if hasattr(args, "model") else None if not model: - model = os.environ.get('OPENAI_MODEL') + model = os.environ.get("OPENAI_MODEL") if not model: - missing.append('Model is required for OpenAI-compatible provider in research-only mode') + missing.append( + "Model is required for OpenAI-compatible provider in research-only mode" + ) return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + class AnthropicStrategy(ProviderStrategy): """Anthropic provider validation strategy.""" - VALID_MODELS = [ - "claude-" - ] + VALID_MODELS = ["claude-"] def validate(self, args: Optional[Any] = None) -> ValidationResult: """Validate Anthropic environment variables and model.""" missing = [] # Check if we're validating expert config - is_expert = args and hasattr(args, 'expert_provider') and args.expert_provider == 'anthropic' + is_expert = ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "anthropic" + ) # Check API key if is_expert: - key = os.environ.get('EXPERT_ANTHROPIC_API_KEY') - if not key or key == '': + key = os.environ.get("EXPERT_ANTHROPIC_API_KEY") + if not key or key == "": # Try to copy from base if not set - base_key = os.environ.get('ANTHROPIC_API_KEY') + base_key = os.environ.get("ANTHROPIC_API_KEY") if base_key: - os.environ['EXPERT_ANTHROPIC_API_KEY'] = base_key + os.environ["EXPERT_ANTHROPIC_API_KEY"] = base_key key = base_key if not key: - missing.append('EXPERT_ANTHROPIC_API_KEY environment variable is not set') + missing.append( + "EXPERT_ANTHROPIC_API_KEY environment variable is not set" + ) else: - key = os.environ.get('ANTHROPIC_API_KEY') + key = os.environ.get("ANTHROPIC_API_KEY") if not key: - missing.append('ANTHROPIC_API_KEY environment variable is not set') + missing.append("ANTHROPIC_API_KEY environment variable is not set") # Check model model_matched = False @@ -156,25 +182,25 @@ class AnthropicStrategy(ProviderStrategy): # First check command line argument if is_expert: - if hasattr(args, 'expert_model') and args.expert_model: + if hasattr(args, "expert_model") and args.expert_model: model_to_check = args.expert_model else: # If no expert model, check environment variable - model_to_check = os.environ.get('EXPERT_ANTHROPIC_MODEL') - if not model_to_check or model_to_check == '': + model_to_check = os.environ.get("EXPERT_ANTHROPIC_MODEL") + if not model_to_check or model_to_check == "": # Try to copy from base if not set - base_model = os.environ.get('ANTHROPIC_MODEL') + base_model = os.environ.get("ANTHROPIC_MODEL") if base_model: - os.environ['EXPERT_ANTHROPIC_MODEL'] = base_model + os.environ["EXPERT_ANTHROPIC_MODEL"] = base_model model_to_check = base_model else: - if hasattr(args, 'model') and args.model: + if hasattr(args, "model") and args.model: model_to_check = args.model else: - model_to_check = os.environ.get('ANTHROPIC_MODEL') + model_to_check = os.environ.get("ANTHROPIC_MODEL") if not model_to_check: - missing.append('ANTHROPIC_MODEL environment variable is not set') + missing.append("ANTHROPIC_MODEL environment variable is not set") return ValidationResult(valid=len(missing) == 0, missing_vars=missing) # Validate model format @@ -184,10 +210,13 @@ class AnthropicStrategy(ProviderStrategy): break if not model_matched: - missing.append(f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}') + missing.append( + f'Invalid Anthropic model: {model_to_check}. Must match one of these patterns: {", ".join(self.VALID_MODELS)}' + ) return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + class OpenRouterStrategy(ProviderStrategy): """OpenRouter provider validation strategy.""" @@ -196,23 +225,30 @@ class OpenRouterStrategy(ProviderStrategy): missing = [] # Check if we're validating expert config - if args and hasattr(args, 'expert_provider') and args.expert_provider == 'openrouter': - key = os.environ.get('EXPERT_OPENROUTER_API_KEY') - if not key or key == '': + if ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "openrouter" + ): + key = os.environ.get("EXPERT_OPENROUTER_API_KEY") + if not key or key == "": # Try to copy from base if not set - base_key = os.environ.get('OPENROUTER_API_KEY') + base_key = os.environ.get("OPENROUTER_API_KEY") if base_key: - os.environ['EXPERT_OPENROUTER_API_KEY'] = base_key + os.environ["EXPERT_OPENROUTER_API_KEY"] = base_key key = base_key if not key: - missing.append('EXPERT_OPENROUTER_API_KEY environment variable is not set') + missing.append( + "EXPERT_OPENROUTER_API_KEY environment variable is not set" + ) else: - key = os.environ.get('OPENROUTER_API_KEY') + key = os.environ.get("OPENROUTER_API_KEY") if not key: - missing.append('OPENROUTER_API_KEY environment variable is not set') + missing.append("OPENROUTER_API_KEY environment variable is not set") return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + class GeminiStrategy(ProviderStrategy): """Gemini provider validation strategy.""" @@ -221,20 +257,24 @@ class GeminiStrategy(ProviderStrategy): missing = [] # Check if we're validating expert config - if args and hasattr(args, 'expert_provider') and args.expert_provider == 'gemini': - key = os.environ.get('EXPERT_GEMINI_API_KEY') - if not key or key == '': + if ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "gemini" + ): + key = os.environ.get("EXPERT_GEMINI_API_KEY") + if not key or key == "": # Try to copy from base if not set - base_key = os.environ.get('GEMINI_API_KEY') + base_key = os.environ.get("GEMINI_API_KEY") if base_key: - os.environ['EXPERT_GEMINI_API_KEY'] = base_key + os.environ["EXPERT_GEMINI_API_KEY"] = base_key key = base_key if not key: - missing.append('EXPERT_GEMINI_API_KEY environment variable is not set') + missing.append("EXPERT_GEMINI_API_KEY environment variable is not set") else: - key = os.environ.get('GEMINI_API_KEY') + key = os.environ.get("GEMINI_API_KEY") if not key: - missing.append('GEMINI_API_KEY environment variable is not set') + missing.append("GEMINI_API_KEY environment variable is not set") return ValidationResult(valid=len(missing) == 0, missing_vars=missing) @@ -246,20 +286,26 @@ class DeepSeekStrategy(ProviderStrategy): """Validate DeepSeek environment variables.""" missing = [] - if args and hasattr(args, 'expert_provider') and args.expert_provider == 'deepseek': - key = os.environ.get('EXPERT_DEEPSEEK_API_KEY') - if not key or key == '': + if ( + args + and hasattr(args, "expert_provider") + and args.expert_provider == "deepseek" + ): + key = os.environ.get("EXPERT_DEEPSEEK_API_KEY") + if not key or key == "": # Try to copy from base if not set - base_key = os.environ.get('DEEPSEEK_API_KEY') + base_key = os.environ.get("DEEPSEEK_API_KEY") if base_key: - os.environ['EXPERT_DEEPSEEK_API_KEY'] = base_key + os.environ["EXPERT_DEEPSEEK_API_KEY"] = base_key key = base_key if not key: - missing.append('EXPERT_DEEPSEEK_API_KEY environment variable is not set') + missing.append( + "EXPERT_DEEPSEEK_API_KEY environment variable is not set" + ) else: - key = os.environ.get('DEEPSEEK_API_KEY') + key = os.environ.get("DEEPSEEK_API_KEY") if not key: - missing.append('DEEPSEEK_API_KEY environment variable is not set') + missing.append("DEEPSEEK_API_KEY environment variable is not set") return ValidationResult(valid=len(missing) == 0, missing_vars=missing) @@ -270,13 +316,14 @@ class OllamaStrategy(ProviderStrategy): def validate(self, args: Optional[Any] = None) -> ValidationResult: """Validate Ollama environment variables.""" missing = [] - - base_url = os.environ.get('OLLAMA_BASE_URL') + + base_url = os.environ.get("OLLAMA_BASE_URL") if not base_url: - missing.append('OLLAMA_BASE_URL environment variable is not set') + missing.append("OLLAMA_BASE_URL environment variable is not set") return ValidationResult(valid=len(missing) == 0, missing_vars=missing) + class ProviderFactory: """Factory for creating provider validation strategies.""" @@ -292,13 +339,13 @@ class ProviderFactory: Provider validation strategy or None if provider not found """ strategies = { - 'openai': OpenAIStrategy(), - 'openai-compatible': OpenAICompatibleStrategy(), - 'anthropic': AnthropicStrategy(), - 'openrouter': OpenRouterStrategy(), - 'gemini': GeminiStrategy(), - 'ollama': OllamaStrategy(), - 'deepseek': DeepSeekStrategy() + "openai": OpenAIStrategy(), + "openai-compatible": OpenAICompatibleStrategy(), + "anthropic": AnthropicStrategy(), + "openrouter": OpenRouterStrategy(), + "gemini": GeminiStrategy(), + "ollama": OllamaStrategy(), + "deepseek": DeepSeekStrategy(), } strategy = strategies.get(provider) return strategy diff --git a/ra_aid/tests/test_env.py b/ra_aid/tests/test_env.py index ce44a3c..3c6f01e 100644 --- a/ra_aid/tests/test_env.py +++ b/ra_aid/tests/test_env.py @@ -1,46 +1,47 @@ """Unit tests for environment validation.""" -import pytest from dataclasses import dataclass -from typing import List + +import pytest + from ra_aid.env import validate_environment + @dataclass class MockArgs: """Mock arguments for testing.""" + research_only: bool provider: str expert_provider: str = None + TEST_CASES = [ pytest.param( "research_only_no_model", MockArgs(research_only=True, provider="openai"), (False, [], False, ["TAVILY_API_KEY environment variable is not set"]), {}, - id="research_only_no_model" + id="research_only_no_model", ), pytest.param( "research_only_with_model", MockArgs(research_only=True, provider="openai"), (False, [], True, []), {"TAVILY_API_KEY": "test_key"}, - id="research_only_with_model" - ) + id="research_only_with_model", + ), ] + @pytest.mark.parametrize("test_name,args,expected,env_vars", TEST_CASES) def test_validate_environment_research_only( - test_name: str, - args: MockArgs, - expected: tuple, - env_vars: dict, - monkeypatch + test_name: str, args: MockArgs, expected: tuple, env_vars: dict, monkeypatch ): """Test validate_environment with research_only flag.""" # Clear any existing environment variables monkeypatch.delenv("TAVILY_API_KEY", raising=False) - + # Set test environment variables for key, value in env_vars.items(): monkeypatch.setenv(key, value) diff --git a/ra_aid/text/__init__.py b/ra_aid/text/__init__.py index d53522c..0b66541 100644 --- a/ra_aid/text/__init__.py +++ b/ra_aid/text/__init__.py @@ -1,3 +1,3 @@ from .processing import truncate_output -__all__ = ['truncate_output'] +__all__ = ["truncate_output"] diff --git a/ra_aid/text/processing.py b/ra_aid/text/processing.py index 360e464..65279ac 100644 --- a/ra_aid/text/processing.py +++ b/ra_aid/text/processing.py @@ -1,42 +1,43 @@ from typing import Optional + def truncate_output(output: str, max_lines: Optional[int] = 5000) -> str: """Truncate output string to keep only the most recent lines if it exceeds max_lines. - + When truncation occurs, adds a message indicating how many lines were removed. Preserves original line endings and handles Unicode characters correctly. - + Args: output: The string output to potentially truncate max_lines: Maximum number of lines to keep (default: 5000) - + Returns: The truncated string if it exceeded max_lines, or the original string if not """ # Handle empty output if not output: return "" - + # Set max_lines to default if None if max_lines is None: max_lines = 5000 - + # Split while preserving line endings lines = output.splitlines(keepends=True) total_lines = len(lines) - + # Return original if under limit if total_lines <= max_lines: return output - + # Calculate lines to remove lines_removed = total_lines - max_lines - + # Keep only the most recent lines truncated_lines = lines[-max_lines:] - + # Add truncation message at start truncation_msg = f"[{lines_removed} lines of output truncated]\n" - + # Combine message with remaining lines return truncation_msg + "".join(truncated_lines) diff --git a/ra_aid/tool_configs.py b/ra_aid/tool_configs.py index bc6eb90..241c32f 100644 --- a/ra_aid/tool_configs.py +++ b/ra_aid/tool_configs.py @@ -1,17 +1,41 @@ from ra_aid.tools import ( - ask_expert, ask_human, run_shell_command, run_programming_task, - emit_research_notes, emit_plan, emit_related_files, - emit_expert_context, emit_key_facts, delete_key_facts, - emit_key_snippets, delete_key_snippets, deregister_related_files, read_file_tool, - fuzzy_find_project_files, ripgrep_search, list_directory_tree, - monorepo_detected, ui_detected, - task_completed, plan_implementation_completed, web_search_tavily + ask_expert, + ask_human, + delete_key_facts, + delete_key_snippets, + deregister_related_files, + emit_expert_context, + emit_key_facts, + emit_key_snippets, + emit_plan, + emit_related_files, + emit_research_notes, + fuzzy_find_project_files, + list_directory_tree, + monorepo_detected, + plan_implementation_completed, + read_file_tool, + ripgrep_search, + run_programming_task, + run_shell_command, + task_completed, + ui_detected, + web_search_tavily, +) +from ra_aid.tools.agent import ( + request_implementation, + request_research, + request_research_and_implementation, + request_task_implementation, + request_web_research, ) from ra_aid.tools.memory import one_shot_completed -from ra_aid.tools.agent import request_research, request_implementation, request_research_and_implementation, request_task_implementation, request_web_research + # Read-only tools that don't modify system state -def get_read_only_tools(human_interaction: bool = False, web_research_enabled: bool = False) -> list: +def get_read_only_tools( + human_interaction: bool = False, web_research_enabled: bool = False +) -> list: """Get the list of read-only tools, optionally including human interaction tools. Args: @@ -32,7 +56,7 @@ def get_read_only_tools(human_interaction: bool = False, web_research_enabled: b read_file_tool, fuzzy_find_project_files, ripgrep_search, - run_shell_command # can modify files, but we still need it for read-only tasks. + run_shell_command, # can modify files, but we still need it for read-only tasks. ] if web_research_enabled: @@ -43,6 +67,7 @@ def get_read_only_tools(human_interaction: bool = False, web_research_enabled: b return tools + # Define constant tool groups READ_ONLY_TOOLS = get_read_only_tools() MODIFICATION_TOOLS = [run_programming_task] @@ -52,10 +77,16 @@ RESEARCH_TOOLS = [ emit_research_notes, one_shot_completed, monorepo_detected, - ui_detected + ui_detected, ] -def get_research_tools(research_only: bool = False, expert_enabled: bool = True, human_interaction: bool = False, web_research_enabled: bool = False) -> list: + +def get_research_tools( + research_only: bool = False, + expert_enabled: bool = True, + human_interaction: bool = False, + web_research_enabled: bool = False, +) -> list: """Get the list of research tools based on mode and whether expert is enabled. Args: @@ -83,7 +114,10 @@ def get_research_tools(research_only: bool = False, expert_enabled: bool = True, return tools -def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list: + +def get_planning_tools( + expert_enabled: bool = True, web_research_enabled: bool = False +) -> list: """Get the list of planning tools based on whether expert is enabled. Args: @@ -97,7 +131,7 @@ def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = planning_tools = [ emit_plan, request_task_implementation, - plan_implementation_completed + plan_implementation_completed, ] tools.extend(planning_tools) @@ -107,7 +141,10 @@ def get_planning_tools(expert_enabled: bool = True, web_research_enabled: bool = return tools -def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list: + +def get_implementation_tools( + expert_enabled: bool = True, web_research_enabled: bool = False +) -> list: """Get the list of implementation tools based on whether expert is enabled. Args: @@ -119,9 +156,7 @@ def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: # Add modification tools since it's not research-only tools.extend(MODIFICATION_TOOLS) - tools.extend([ - task_completed - ]) + tools.extend([task_completed]) # Add expert tools if enabled if expert_enabled: @@ -129,6 +164,7 @@ def get_implementation_tools(expert_enabled: bool = True, web_research_enabled: return tools + def get_web_research_tools(expert_enabled: bool = True) -> list: """Get the list of tools available for web research. @@ -140,11 +176,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list: Returns: list: List of tools configured for web research """ - tools = [ - web_search_tavily, - emit_research_notes, - task_completed - ] + tools = [web_search_tavily, emit_research_notes, task_completed] if expert_enabled: tools.append(emit_expert_context) @@ -152,7 +184,10 @@ def get_web_research_tools(expert_enabled: bool = True) -> list: return tools -def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False) -> list: + +def get_chat_tools( + expert_enabled: bool = True, web_research_enabled: bool = False +) -> list: """Get the list of tools available in chat mode. Chat mode includes research and implementation capabilities but excludes @@ -169,7 +204,7 @@ def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = Fal emit_key_facts, delete_key_facts, delete_key_snippets, - deregister_related_files + deregister_related_files, ] if web_research_enabled: diff --git a/ra_aid/tools/__init__.py b/ra_aid/tools/__init__.py index ccd1334..a39c63b 100644 --- a/ra_aid/tools/__init__.py +++ b/ra_aid/tools/__init__.py @@ -1,51 +1,62 @@ +from .expert import ask_expert, emit_expert_context +from .file_str_replace import file_str_replace +from .fuzzy_find import fuzzy_find_project_files +from .human import ask_human +from .list_directory import list_directory_tree +from .memory import ( + delete_key_facts, + delete_key_snippets, + delete_tasks, + deregister_related_files, + emit_key_facts, + emit_key_snippets, + emit_plan, + emit_related_files, + emit_research_notes, + emit_task, + get_memory_value, + plan_implementation_completed, + request_implementation, + swap_task_order, + task_completed, +) +from .programmer import run_programming_task +from .read_file import read_file_tool +from .research import existing_project_detected, monorepo_detected, ui_detected +from .ripgrep import ripgrep_search from .shell import run_shell_command from .web_search_tavily import web_search_tavily -from .research import monorepo_detected, existing_project_detected, ui_detected -from .human import ask_human -from .programmer import run_programming_task -from .expert import ask_expert, emit_expert_context -from .read_file import read_file_tool -from .file_str_replace import file_str_replace from .write_file import write_file_tool -from .fuzzy_find import fuzzy_find_project_files -from .list_directory import list_directory_tree -from .ripgrep import ripgrep_search -from .memory import ( - delete_tasks, emit_research_notes, emit_plan, emit_task, get_memory_value, emit_key_facts, - request_implementation, delete_key_facts, - emit_key_snippets, delete_key_snippets, emit_related_files, swap_task_order, task_completed, - plan_implementation_completed, deregister_related_files -) __all__ = [ - 'ask_expert', - 'delete_key_facts', - 'delete_key_snippets', - 'web_search_tavily', - 'deregister_related_files', - 'emit_expert_context', - 'emit_key_facts', - 'emit_key_snippets', - 'emit_plan', - 'emit_related_files', - 'emit_research_notes', - 'emit_task', - 'fuzzy_find_project_files', - 'get_memory_value', - 'list_directory_tree', - 'read_file_tool', - 'request_implementation', - 'run_programming_task', - 'run_shell_command', - 'write_file_tool', - 'ripgrep_search', - 'file_str_replace', - 'delete_tasks', - 'swap_task_order', - 'monorepo_detected', - 'existing_project_detected', - 'ui_detected', - 'ask_human', - 'task_completed', - 'plan_implementation_completed' + "ask_expert", + "delete_key_facts", + "delete_key_snippets", + "web_search_tavily", + "deregister_related_files", + "emit_expert_context", + "emit_key_facts", + "emit_key_snippets", + "emit_plan", + "emit_related_files", + "emit_research_notes", + "emit_task", + "fuzzy_find_project_files", + "get_memory_value", + "list_directory_tree", + "read_file_tool", + "request_implementation", + "run_programming_task", + "run_shell_command", + "write_file_tool", + "ripgrep_search", + "file_str_replace", + "delete_tasks", + "swap_task_order", + "monorepo_detected", + "existing_project_detected", + "ui_detected", + "ask_human", + "task_completed", + "plan_implementation_completed", ] diff --git a/ra_aid/tools/agent.py b/ra_aid/tools/agent.py index 4f3fc5d..1872d9c 100644 --- a/ra_aid/tools/agent.py +++ b/ra_aid/tools/agent.py @@ -1,18 +1,20 @@ """Tools for spawning and managing sub-agents.""" +from typing import Any, Dict, List, Union + from langchain_core.tools import tool -from typing import Dict, Any, Union, List -from typing_extensions import TypeAlias -from ..agent_utils import AgentInterrupt -from ra_aid.exceptions import AgentInterrupt -ResearchResult = Dict[str, Union[str, bool, Dict[int, Any], List[Any], None]] from rich.console import Console -from ra_aid.tools.memory import _global_memory + from ra_aid.console.formatting import print_error -from .memory import get_memory_value, get_related_files, get_work_log -from .human import ask_human -from ..llm import initialize_llm +from ra_aid.exceptions import AgentInterrupt +from ra_aid.tools.memory import _global_memory + from ..console import print_task_header +from ..llm import initialize_llm +from .human import ask_human +from .memory import get_memory_value, get_related_files, get_work_log + +ResearchResult = Dict[str, Union[str, bool, Dict[int, Any], List[Any], None]] CANCELLED_BY_USER_REASON = "The operation was explicitly cancelled by the user. This typically is an indication that the action requested was not aligned with the user request." @@ -20,22 +22,26 @@ RESEARCH_AGENT_RECURSION_LIMIT = 3 console = Console() + @tool("request_research") def request_research(query: str) -> ResearchResult: """Spawn a research-only agent to investigate the given query. - This function creates a new research agent to investigate the given query. It includes + This function creates a new research agent to investigate the given query. It includes recursion depth limiting to prevent infinite recursive research calls. Args: query: The research question or project description """ # Initialize model from config - config = _global_memory.get('config', {}) - model = initialize_llm(config.get('provider', 'anthropic'), config.get('model', 'claude-3-5-sonnet-20241022')) - + config = _global_memory.get("config", {}) + model = initialize_llm( + config.get("provider", "anthropic"), + config.get("model", "claude-3-5-sonnet-20241022"), + ) + # Check recursion depth - current_depth = _global_memory.get('agent_depth', 0) + current_depth = _global_memory.get("agent_depth", 0) if current_depth >= RESEARCH_AGENT_RECURSION_LIMIT: print_error("Maximum research recursion depth reached") return { @@ -45,23 +51,24 @@ def request_research(query: str) -> ResearchResult: "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), "success": False, - "reason": "max_depth_exceeded" + "reason": "max_depth_exceeded", } success = True reason = None - + try: # Run research agent from ..agent_utils import run_research_agent - result = run_research_agent( + + _result = run_research_agent( query, model, expert_enabled=True, research_only=True, - hil=config.get('hil', False), + hil=config.get("hil", False), console_message=query, - config=config + config=config, ) except AgentInterrupt: print() @@ -76,14 +83,17 @@ def request_research(query: str) -> ResearchResult: reason = f"error: {str(e)}" finally: # Get completion message if available - completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) - + completion_message = _global_memory.get( + "completion_message", + "Task was completed successfully." if success else None, + ) + work_log = get_work_log() - + # Clear completion state from global memory - _global_memory['completion_message'] = '' - _global_memory['task_completed'] = False - + _global_memory["completion_message"] = "" + _global_memory["task_completed"] = False + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), @@ -91,36 +101,41 @@ def request_research(query: str) -> ResearchResult: "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), "success": success, - "reason": reason + "reason": reason, } if work_log is not None: response_data["work_log"] = work_log return response_data + @tool("request_web_research") def request_web_research(query: str) -> ResearchResult: """Spawn a web research agent to investigate the given query using web search. - + Args: query: The research question or project description """ # Initialize model from config - config = _global_memory.get('config', {}) - model = initialize_llm(config.get('provider', 'anthropic'), config.get('model', 'claude-3-5-sonnet-20241022')) - + config = _global_memory.get("config", {}) + model = initialize_llm( + config.get("provider", "anthropic"), + config.get("model", "claude-3-5-sonnet-20241022"), + ) + success = True reason = None - + try: # Run web research agent from ..agent_utils import run_web_research_agent - result = run_web_research_agent( + + _result = run_web_research_agent( query, model, expert_enabled=True, - hil=config.get('hil', False), + hil=config.get("hil", False), console_message=query, - config=config + config=config, ) except AgentInterrupt: print() @@ -135,52 +150,60 @@ def request_web_research(query: str) -> ResearchResult: reason = f"error: {str(e)}" finally: # Get completion message if available - completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) - + completion_message = _global_memory.get( + "completion_message", + "Task was completed successfully." if success else None, + ) + work_log = get_work_log() - + # Clear completion state from global memory - _global_memory['completion_message'] = '' - _global_memory['task_completed'] = False - + _global_memory["completion_message"] = "" + _global_memory["task_completed"] = False + response_data = { "completion_message": completion_message, "key_snippets": get_memory_value("key_snippets"), "research_notes": get_memory_value("research_notes"), "success": success, - "reason": reason + "reason": reason, } if work_log is not None: response_data["work_log"] = work_log return response_data + @tool("request_research_and_implementation") def request_research_and_implementation(query: str) -> Dict[str, Any]: """Spawn a research agent to investigate and implement the given query. - + If you are calling this on behalf of a user request, you must *faithfully* represent all info the user gave you, sometimes even to the point of repeating the user query verbatim. - + Args: query: The research question or project description """ # Initialize model from config - config = _global_memory.get('config', {}) - model = initialize_llm(config.get('provider', 'anthropic'), config.get('model', 'claude-3-5-sonnet-20241022')) - + config = _global_memory.get("config", {}) + model = initialize_llm( + config.get("provider", "anthropic"), + config.get("model", "claude-3-5-sonnet-20241022"), + ) + try: # Run research agent from ..agent_utils import run_research_agent - result = run_research_agent( + + _result = run_research_agent( query, model, expert_enabled=True, research_only=False, - hil=config.get('hil', False), + hil=config.get("hil", False), console_message=query, - config=config + config=config, ) - + success = True reason = None except AgentInterrupt: @@ -194,16 +217,18 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: console.print(f"\n[red]Error during research: {str(e)}[/red]") success = False reason = f"error: {str(e)}" - + # Get completion message if available - completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) - + completion_message = _global_memory.get( + "completion_message", "Task was completed successfully." if success else None + ) + work_log = get_work_log() - + # Clear completion state from global memory - _global_memory['completion_message'] = '' - _global_memory['task_completed'] = False - _global_memory['plan_completed'] = False + _global_memory["completion_message"] = "" + _global_memory["task_completed"] = False + _global_memory["plan_completed"] = False response_data = { "completion_message": completion_message, @@ -212,43 +237,50 @@ def request_research_and_implementation(query: str) -> Dict[str, Any]: "research_notes": get_memory_value("research_notes"), "key_snippets": get_memory_value("key_snippets"), "success": success, - "reason": reason + "reason": reason, } if work_log is not None: response_data["work_log"] = work_log return response_data + @tool("request_task_implementation") def request_task_implementation(task_spec: str) -> Dict[str, Any]: """Spawn an implementation agent to execute the given task. - + Args: task_spec: REQUIRED The full task specification (markdown format, typically one part of the overall plan) """ # Initialize model from config - config = _global_memory.get('config', {}) - model = initialize_llm(config.get('provider', 'anthropic'), config.get('model', 'claude-3-5-sonnet-20241022')) - + config = _global_memory.get("config", {}) + model = initialize_llm( + config.get("provider", "anthropic"), + config.get("model", "claude-3-5-sonnet-20241022"), + ) + # Get required parameters - tasks = [_global_memory['tasks'][task_id] for task_id in sorted(_global_memory['tasks'])] - plan = _global_memory.get('plan', '') - related_files = list(_global_memory['related_files'].values()) - + tasks = [ + _global_memory["tasks"][task_id] for task_id in sorted(_global_memory["tasks"]) + ] + plan = _global_memory.get("plan", "") + related_files = list(_global_memory["related_files"].values()) + try: print_task_header(task_spec) # Run implementation agent from ..agent_utils import run_task_implementation_agent - result = run_task_implementation_agent( - base_task=_global_memory.get('base_task', ''), + + _result = run_task_implementation_agent( + base_task=_global_memory.get("base_task", ""), tasks=tasks, task=task_spec, - plan=plan, + plan=plan, related_files=related_files, model=model, expert_enabled=True, - config=config + config=config, ) - + success = True reason = None except AgentInterrupt: @@ -262,51 +294,58 @@ def request_task_implementation(task_spec: str) -> Dict[str, Any]: print_error(f"Error during task implementation: {str(e)}") success = False reason = f"error: {str(e)}" - + # Get completion message if available - completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) - + completion_message = _global_memory.get( + "completion_message", "Task was completed successfully." if success else None + ) + # Get and reset work log if at root depth work_log = get_work_log() - + # Clear completion state from global memory - _global_memory['completion_message'] = '' - _global_memory['task_completed'] = False - + _global_memory["completion_message"] = "" + _global_memory["task_completed"] = False + response_data = { "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), "completion_message": completion_message, "success": success, - "reason": reason + "reason": reason, } if work_log is not None: response_data["work_log"] = work_log return response_data + @tool("request_implementation") def request_implementation(task_spec: str) -> Dict[str, Any]: """Spawn a planning agent to create an implementation plan for the given task. - + Args: task_spec: The task specification to plan implementation for """ # Initialize model from config - config = _global_memory.get('config', {}) - model = initialize_llm(config.get('provider', 'anthropic'), config.get('model', 'claude-3-5-sonnet-20241022')) - + config = _global_memory.get("config", {}) + model = initialize_llm( + config.get("provider", "anthropic"), + config.get("model", "claude-3-5-sonnet-20241022"), + ) + try: # Run planning agent from ..agent_utils import run_planning_agent - result = run_planning_agent( + + _result = run_planning_agent( task_spec, model, config=config, expert_enabled=True, - hil=config.get('hil', False) + hil=config.get("hil", False), ) - + success = True reason = None except AgentInterrupt: @@ -320,25 +359,27 @@ def request_implementation(task_spec: str) -> Dict[str, Any]: print_error(f"Error during planning: {str(e)}") success = False reason = f"error: {str(e)}" - + # Get completion message if available - completion_message = _global_memory.get('completion_message', 'Task was completed successfully.' if success else None) - + completion_message = _global_memory.get( + "completion_message", "Task was completed successfully." if success else None + ) + # Get and reset work log if at root depth work_log = get_work_log() - + # Clear completion state from global memory - _global_memory['completion_message'] = '' - _global_memory['task_completed'] = False - _global_memory['plan_completed'] = False - + _global_memory["completion_message"] = "" + _global_memory["task_completed"] = False + _global_memory["plan_completed"] = False + response_data = { "completion_message": completion_message, "key_facts": get_memory_value("key_facts"), "related_files": get_related_files(), "key_snippets": get_memory_value("key_snippets"), "success": success, - "reason": reason + "reason": reason, } if work_log is not None: response_data["work_log"] = work_log diff --git a/ra_aid/tools/expert.py b/ra_aid/tools/expert.py index 4ec370d..d760a7d 100644 --- a/ra_aid/tools/expert.py +++ b/ra_aid/tools/expert.py @@ -1,34 +1,45 @@ -from typing import List import os +from typing import List + from langchain_core.tools import tool from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel + from ..llm import initialize_expert_llm -from .memory import get_memory_value, _global_memory +from .memory import _global_memory, get_memory_value console = Console() _model = None + def get_model(): global _model try: if _model is None: - provider = _global_memory['config']['expert_provider'] or 'openai' - model = _global_memory['config']['expert_model'] or 'o1' + provider = _global_memory["config"]["expert_provider"] or "openai" + model = _global_memory["config"]["expert_model"] or "o1" _model = initialize_expert_llm(provider, model) except Exception as e: _model = None - console.print(Panel(f"Failed to initialize expert model: {e}", title="Error", border_style="red")) + console.print( + Panel( + f"Failed to initialize expert model: {e}", + title="Error", + border_style="red", + ) + ) raise return _model + # Keep track of context globally expert_context = { - 'text': [], # Additional textual context - 'files': [] # File paths to include + "text": [], # Additional textual context + "files": [], # File paths to include } + @tool("emit_expert_context") def emit_expert_context(context: str) -> str: """Add context for the next expert question. @@ -42,25 +53,26 @@ def emit_expert_context(context: str) -> str: You must give the complete contents. Expert context will be reset after the ask_expert tool is called. - + Args: context: The context to add """ - expert_context['text'].append(context) - + expert_context["text"].append(context) + # Create and display status panel panel_content = f"Added expert context ({len(context)} characters)" console.print(Panel(panel_content, title="Expert Context", border_style="blue")) - - return f"Context added." + + return "Context added." + def read_files_with_limit(file_paths: List[str], max_lines: int = 10000) -> str: """Read multiple files and concatenate contents, stopping at line limit. - + Args: file_paths: List of file paths to read max_lines: Maximum total lines to read (default: 10000) - + Note: - Each file's contents will be prefaced with its path as a header - Stops reading files when max_lines limit is reached @@ -68,46 +80,50 @@ def read_files_with_limit(file_paths: List[str], max_lines: int = 10000) -> str: """ total_lines = 0 contents = [] - + for path in file_paths: try: if not os.path.exists(path): console.print(f"Warning: File not found: {path}", style="yellow") continue - - with open(path, 'r', encoding='utf-8') as f: + + with open(path, "r", encoding="utf-8") as f: file_content = [] for i, line in enumerate(f): if total_lines + i >= max_lines: - file_content.append(f"\n... truncated after {max_lines} lines ...") + file_content.append( + f"\n... truncated after {max_lines} lines ..." + ) break file_content.append(line) - + if file_content: - contents.append(f'\n## File: {path}\n') - contents.append(''.join(file_content)) + contents.append(f"\n## File: {path}\n") + contents.append("".join(file_content)) total_lines += len(file_content) - + except Exception as e: console.print(f"Error reading file {path}: {str(e)}", style="red") continue - - return ''.join(contents) + + return "".join(contents) + def read_related_files(file_paths: List[str]) -> str: """Read the provided files and return their contents. - + Args: file_paths: List of file paths to read - + Returns: String containing concatenated file contents, or empty string if no paths """ if not file_paths: - return '' - + return "" + return read_files_with_limit(file_paths, max_lines=10000) + @tool("ask_expert") def ask_expert(question: str) -> str: """Ask a question to an expert AI model. @@ -126,60 +142,60 @@ def ask_expert(question: str) -> str: The expert can be prone to overthinking depending on what and how you ask it. """ global expert_context - + # Get all content first - file_paths = list(_global_memory['related_files'].values()) + file_paths = list(_global_memory["related_files"].values()) related_contents = read_related_files(file_paths) - key_snippets = get_memory_value('key_snippets') - key_facts = get_memory_value('key_facts') - research_notes = get_memory_value('research_notes') - + key_snippets = get_memory_value("key_snippets") + key_facts = get_memory_value("key_facts") + research_notes = get_memory_value("research_notes") + # Build display query (just question) display_query = "# Question\n" + question - + # Show only question in panel - console.print(Panel( - Markdown(display_query), - title="πŸ€” Expert Query", - border_style="yellow" - )) - + console.print( + Panel(Markdown(display_query), title="πŸ€” Expert Query", border_style="yellow") + ) + # Clear context after panel display - expert_context['text'].clear() - expert_context['files'].clear() - + expert_context["text"].clear() + expert_context["files"].clear() + # Build full query in specified order query_parts = [] - + if related_contents: - query_parts.extend(['# Related Files', related_contents]) - + query_parts.extend(["# Related Files", related_contents]) + if related_contents: - query_parts.extend(['# Research Notes', research_notes]) - + query_parts.extend(["# Research Notes", research_notes]) + if key_snippets and len(key_snippets) > 0: - query_parts.extend(['# Key Snippets', key_snippets]) - + query_parts.extend(["# Key Snippets", key_snippets]) + if key_facts and len(key_facts) > 0: - query_parts.extend(['# Key Facts About This Project', key_facts]) - - if expert_context['text']: - query_parts.extend(['\n# Additional Context', '\n'.join(expert_context['text'])]) - - query_parts.extend(['# Question', question]) - query_parts.extend(['\n # Addidional Requirements', "Do not expand the scope unnecessarily."]) - + query_parts.extend(["# Key Facts About This Project", key_facts]) + + if expert_context["text"]: + query_parts.extend( + ["\n# Additional Context", "\n".join(expert_context["text"])] + ) + + query_parts.extend(["# Question", question]) + query_parts.extend( + ["\n # Addidional Requirements", "Do not expand the scope unnecessarily."] + ) + # Join all parts - full_query = '\n'.join(query_parts) - + full_query = "\n".join(query_parts) + # Get response using full query response = get_model().invoke(full_query) - + # Format and display response - console.print(Panel( - Markdown(response.content), - title="Expert Response", - border_style="blue" - )) - + console.print( + Panel(Markdown(response.content), title="Expert Response", border_style="blue") + ) + return response.content diff --git a/ra_aid/tools/file_str_replace.py b/ra_aid/tools/file_str_replace.py index af430e6..ba596e9 100644 --- a/ra_aid/tools/file_str_replace.py +++ b/ra_aid/tools/file_str_replace.py @@ -1,17 +1,20 @@ -from langchain_core.tools import tool -from typing import Dict from pathlib import Path +from typing import Dict + +from langchain_core.tools import tool from rich.panel import Panel + from ra_aid.console import console from ra_aid.console.formatting import print_error + def truncate_display_str(s: str, max_length: int = 30) -> str: """Truncate a string for display purposes if it exceeds max length. - + Args: s: String to truncate max_length: Maximum length before truncating - + Returns: Truncated string with ellipsis if needed """ @@ -19,34 +22,32 @@ def truncate_display_str(s: str, max_length: int = 30) -> str: return s return s[:max_length] + "..." + def format_string_for_display(s: str, threshold: int = 30) -> str: """Format a string for display, showing either quoted string or length. - + Args: s: String to format threshold: Max length before switching to character count display - + Returns: Formatted string for display """ if len(s) <= threshold: return f"'{s}'" - return f'[{len(s)} characters]' + return f"[{len(s)} characters]" + @tool -def file_str_replace( - filepath: str, - old_str: str, - new_str: str -) -> Dict[str, any]: +def file_str_replace(filepath: str, old_str: str, new_str: str) -> Dict[str, any]: """Replace an exact string match in a file with a new string. Only performs replacement if the old string appears exactly once. - + Args: filepath: Path to the file to modify old_str: Exact string to replace new_str: String to replace with - + Returns: Dict containing: - success: Whether the operation succeeded @@ -58,10 +59,10 @@ def file_str_replace( msg = f"File not found: {filepath}" print_error(msg) return {"success": False, "message": msg} - + content = path.read_text() count = content.count(old_str) - + if count == 0: msg = f"String not found: {truncate_display_str(old_str)}" print_error(msg) @@ -70,20 +71,22 @@ def file_str_replace( msg = f"String appears {count} times - must be unique" print_error(msg) return {"success": False, "message": msg} - + new_content = content.replace(old_str, new_str) path.write_text(new_content) - - console.print(Panel( - f"Replaced in {filepath}:\n{format_string_for_display(old_str)} β†’ {format_string_for_display(new_str)}", - title="βœ“ String Replaced", - border_style="bright_blue" - )) + + console.print( + Panel( + f"Replaced in {filepath}:\n{format_string_for_display(old_str)} β†’ {format_string_for_display(new_str)}", + title="βœ“ String Replaced", + border_style="bright_blue", + ) + ) return { "success": True, - "message": f"Successfully replaced '{old_str}' with '{new_str}' in {filepath}" + "message": f"Successfully replaced '{old_str}' with '{new_str}' in {filepath}", } - + except Exception as e: msg = f"Error: {str(e)}" print_error(msg) diff --git a/ra_aid/tools/fuzzy_find.py b/ra_aid/tools/fuzzy_find.py index 1a1dc6e..0b0deee 100644 --- a/ra_aid/tools/fuzzy_find.py +++ b/ra_aid/tools/fuzzy_find.py @@ -1,23 +1,25 @@ -from typing import List, Tuple import fnmatch -from git import Repo +from typing import List, Tuple + from fuzzywuzzy import process +from git import Repo from langchain_core.tools import tool from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel console = Console() DEFAULT_EXCLUDE_PATTERNS = [ - '*.pyc', - '__pycache__/*', - '.git/*', - '*.so', - '*.o', - '*.class' + "*.pyc", + "__pycache__/*", + ".git/*", + "*.so", + "*.o", + "*.class", ] + @tool def fuzzy_find_project_files( search_term: str, @@ -26,14 +28,14 @@ def fuzzy_find_project_files( threshold: int = 60, max_results: int = 10, include_paths: List[str] = None, - exclude_patterns: List[str] = None + exclude_patterns: List[str] = None, ) -> List[Tuple[str, int]]: """Fuzzy find files in a git repository matching the search term. - + This tool searches for files within a git repository using fuzzy string matching, allowing for approximate matches to the search term. It returns a list of matched files along with their match scores. - + Args: search_term: String to match against file paths repo_path: Path to git repository (defaults to current directory) @@ -41,10 +43,10 @@ def fuzzy_find_project_files( max_results: Maximum number of results to return (default: 10) include_paths: Optional list of path patterns to include in search exclude_patterns: Optional list of path patterns to exclude from search - + Returns: List of tuples containing (file_path, match_score) - + Raises: InvalidGitRepositoryError: If repo_path is not a git repository ValueError: If threshold is not between 0 and 100 @@ -52,65 +54,51 @@ def fuzzy_find_project_files( # Validate threshold if not 0 <= threshold <= 100: raise ValueError("Threshold must be between 0 and 100") - + # Handle empty search term as special case if not search_term: return [] # Initialize repo for normal search repo = Repo(repo_path) - + # Get all tracked files tracked_files = repo.git.ls_files().splitlines() - + # Get all untracked files untracked_files = repo.untracked_files - + # Combine file lists all_files = tracked_files + untracked_files - + # Apply include patterns if specified if include_paths: filtered_files = [] for pattern in include_paths: - filtered_files.extend( - f for f in all_files - if fnmatch.fnmatch(f, pattern) - ) + filtered_files.extend(f for f in all_files if fnmatch.fnmatch(f, pattern)) all_files = filtered_files - + # Apply exclude patterns patterns = DEFAULT_EXCLUDE_PATTERNS + (exclude_patterns or []) for pattern in patterns: - all_files = [ - f for f in all_files - if not fnmatch.fnmatch(f, pattern) - ] - + all_files = [f for f in all_files if not fnmatch.fnmatch(f, pattern)] + # Perform fuzzy matching - matches = process.extract( - search_term, - all_files, - limit=max_results - ) - + matches = process.extract(search_term, all_files, limit=max_results) + # Filter by threshold - filtered_matches = [ - (path, score) - for path, score in matches - if score >= threshold - ] + filtered_matches = [(path, score) for path, score in matches if score >= threshold] # Build info panel content info_sections = [] - + # Search parameters section params_section = [ "## Search Parameters", f"**Search Term**: `{search_term}`", f"**Repository**: `{repo_path}`", f"**Threshold**: {threshold}", - f"**Max Results**: {max_results}" + f"**Max Results**: {max_results}", ] if include_paths: params_section.append("\n**Include Patterns**:") @@ -126,7 +114,7 @@ def fuzzy_find_project_files( stats_section = [ "## Results Statistics", f"**Total Files Scanned**: {len(all_files)}", - f"**Matches Found**: {len(filtered_matches)}" + f"**Matches Found**: {len(filtered_matches)}", ] info_sections.append("\n".join(stats_section)) @@ -140,10 +128,12 @@ def fuzzy_find_project_files( info_sections.append("## Results\n*No matches found*") # Display the panel - console.print(Panel( - Markdown("\n\n".join(info_sections)), - title="πŸ” Fuzzy Find Results", - border_style="bright_blue" - )) - + console.print( + Panel( + Markdown("\n\n".join(info_sections)), + title="πŸ” Fuzzy Find Results", + border_style="bright_blue", + ) + ) + return filtered_matches diff --git a/ra_aid/tools/handle_user_defined_test_cmd_execution.py b/ra_aid/tools/handle_user_defined_test_cmd_execution.py index 35eb7da..92bd998 100644 --- a/ra_aid/tools/handle_user_defined_test_cmd_execution.py +++ b/ra_aid/tools/handle_user_defined_test_cmd_execution.py @@ -1,32 +1,43 @@ """Utilities for executing and managing user-defined test commands.""" -from typing import Dict, Any, Tuple import subprocess from dataclasses import dataclass +from typing import Any, Dict, Tuple + from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel + +from ra_aid.logging_config import get_logger from ra_aid.tools.human import ask_human from ra_aid.tools.shell import run_shell_command -from ra_aid.logging_config import get_logger console = Console() logger = get_logger(__name__) + @dataclass class TestState: """State for test execution.""" + prompt: str test_attempts: int auto_test: bool should_break: bool = False + class TestCommandExecutor: """Class for executing and managing test commands.""" - - def __init__(self, config: Dict[str, Any], original_prompt: str, test_attempts: int = 0, auto_test: bool = False): + + def __init__( + self, + config: Dict[str, Any], + original_prompt: str, + test_attempts: int = 0, + auto_test: bool = False, + ): """Initialize the test command executor. - + Args: config: Configuration dictionary containing test settings original_prompt: The original prompt to append errors to @@ -38,7 +49,7 @@ class TestCommandExecutor: prompt=original_prompt, test_attempts=test_attempts, auto_test=auto_test, - should_break=False + should_break=False, ) self.max_retries = config.get("max_test_cmd_retries", 5) @@ -46,15 +57,19 @@ class TestCommandExecutor: """Display test failure message.""" console.print( Panel( - Markdown(f"Test failed. Attempt number {self.state.test_attempts} of {self.max_retries}. Retrying and informing of failure output"), + Markdown( + f"Test failed. Attempt number {self.state.test_attempts} of {self.max_retries}. Retrying and informing of failure output" + ), title="πŸ”Ž User Defined Test", - border_style="red bold" + border_style="red bold", ) ) - def handle_test_failure(self, original_prompt: str, test_result: Dict[str, Any]) -> None: + def handle_test_failure( + self, original_prompt: str, test_result: Dict[str, Any] + ) -> None: """Handle test command failure. - + Args: original_prompt: Original prompt text test_result: Test command result @@ -65,42 +80,48 @@ class TestCommandExecutor: def run_test_command(self, cmd: str, original_prompt: str) -> None: """Run test command and handle result. - + Args: cmd: Test command to execute original_prompt: Original prompt text """ - timeout = self.config.get('timeout', 30) + timeout = self.config.get("timeout", 30) try: logger.info(f"Executing test command: {cmd} with timeout {timeout}s") test_result = run_shell_command(cmd, timeout=timeout) self.state.test_attempts += 1 - + if not test_result["success"]: self.handle_test_failure(original_prompt, test_result) return - + self.state.should_break = True logger.info("Test command executed successfully") - - except subprocess.TimeoutExpired as e: + + except subprocess.TimeoutExpired: logger.warning(f"Test command timed out after {timeout}s: {cmd}") self.state.test_attempts += 1 - self.state.prompt = f"{original_prompt}. Previous attempt timed out after {timeout} seconds" + self.state.prompt = ( + f"{original_prompt}. Previous attempt timed out after {timeout} seconds" + ) self.display_test_failure() - + except subprocess.CalledProcessError as e: - logger.error(f"Test command failed with exit code {e.returncode}: {cmd}\nOutput: {e.output}") + logger.error( + f"Test command failed with exit code {e.returncode}: {cmd}\nOutput: {e.output}" + ) self.state.test_attempts += 1 self.state.prompt = f"{original_prompt}. Previous attempt failed with exit code {e.returncode}: {e.output}" self.display_test_failure() - + except Exception as e: logger.warning(f"Test command execution failed: {str(e)}") self.state.test_attempts += 1 self.state.should_break = True - def handle_user_response(self, response: str, cmd: str, original_prompt: str) -> None: + def handle_user_response( + self, response: str, cmd: str, original_prompt: str + ) -> None: """Handle user's response to test prompt. Args: response: User's response (y/n/a) @@ -108,22 +129,22 @@ class TestCommandExecutor: original_prompt: Original prompt text """ response = response.strip().lower() - + if response == "n": self.state.should_break = True return - + if response == "a": self.state.auto_test = True self.run_test_command(cmd, original_prompt) return - + if response == "y": self.run_test_command(cmd, original_prompt) def check_max_retries(self) -> bool: """Check if max retries reached. - + Returns: True if max retries reached """ @@ -134,7 +155,7 @@ class TestCommandExecutor: def execute(self) -> Tuple[bool, str, bool, int]: """Execute test command and handle retries. - + Returns: Tuple containing: - bool: Whether to break the retry loop @@ -144,23 +165,46 @@ class TestCommandExecutor: """ if not self.config.get("test_cmd"): self.state.should_break = True - return self.state.should_break, self.state.prompt, self.state.auto_test, self.state.test_attempts + return ( + self.state.should_break, + self.state.prompt, + self.state.auto_test, + self.state.test_attempts, + ) cmd = self.config["test_cmd"] if not self.state.auto_test: print() - response = ask_human.invoke({"question": "Would you like to run the test command? (y=yes, n=no, a=enable auto-test)"}) + response = ask_human.invoke( + { + "question": "Would you like to run the test command? (y=yes, n=no, a=enable auto-test)" + } + ) self.handle_user_response(response, cmd, self.state.prompt) else: if self.check_max_retries(): - logger.error(f"Maximum number of test retries ({self.max_retries}) reached. Stopping test execution.") - console.print(Panel(f"Maximum retries ({self.max_retries}) reached. Test execution stopped.", title="⚠️ Test Execution", border_style="yellow bold")) + logger.error( + f"Maximum number of test retries ({self.max_retries}) reached. Stopping test execution." + ) + console.print( + Panel( + f"Maximum retries ({self.max_retries}) reached. Test execution stopped.", + title="⚠️ Test Execution", + border_style="yellow bold", + ) + ) self.state.should_break = True else: self.run_test_command(cmd, self.state.prompt) - return self.state.should_break, self.state.prompt, self.state.auto_test, self.state.test_attempts + return ( + self.state.should_break, + self.state.prompt, + self.state.auto_test, + self.state.test_attempts, + ) + def execute_test_command( config: Dict[str, Any], @@ -169,13 +213,13 @@ def execute_test_command( auto_test: bool = False, ) -> Tuple[bool, str, bool, int]: """Execute a test command and handle retries. - + Args: config: Configuration dictionary containing test settings original_prompt: The original prompt to append errors to test_attempts: Current number of test attempts auto_test: Whether auto-test mode is enabled - + Returns: Tuple containing: - bool: Whether to break the retry loop @@ -184,4 +228,4 @@ def execute_test_command( - int: Updated test_attempts count """ executor = TestCommandExecutor(config, original_prompt, test_attempts, auto_test) - return executor.execute() \ No newline at end of file + return executor.execute() diff --git a/ra_aid/tools/human.py b/ra_aid/tools/human.py index f604642..60233a3 100644 --- a/ra_aid/tools/human.py +++ b/ra_aid/tools/human.py @@ -2,50 +2,54 @@ from langchain_core.tools import tool from prompt_toolkit import PromptSession from prompt_toolkit.key_binding import KeyBindings from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel console = Console() + def create_keybindings(): """Create custom key bindings for Ctrl+D submission.""" bindings = KeyBindings() - @bindings.add('c-d') + @bindings.add("c-d") def submit(event): """Trigger submission when Ctrl+D is pressed.""" event.current_buffer.validate_and_handle() return bindings + @tool def ask_human(question: str) -> str: """Ask the human user a question with a nicely formatted display. - + Args: question: The question to ask the human user (supports markdown) - + Returns: The user's response as a string """ - console.print(Panel( - Markdown(question + "\n\n*Multiline input is supported; use Ctrl+D to submit. Use Ctrl+C to exit the program.*"), - title="πŸ’­ Question for Human", - border_style="yellow bold" - )) + console.print( + Panel( + Markdown( + question + + "\n\n*Multiline input is supported; use Ctrl+D to submit. Use Ctrl+C to exit the program.*" + ), + title="πŸ’­ Question for Human", + border_style="yellow bold", + ) + ) session = PromptSession( multiline=True, key_bindings=create_keybindings(), - prompt_continuation='. ', + prompt_continuation=". ", ) print() - - response = session.prompt( - "> ", - wrap_lines=True - ) + + response = session.prompt("> ", wrap_lines=True) print() return response diff --git a/ra_aid/tools/list_directory.py b/ra_aid/tools/list_directory.py index 118d7de..1cc7085 100644 --- a/ra_aid/tools/list_directory.py +++ b/ra_aid/tools/list_directory.py @@ -1,39 +1,45 @@ +import datetime +import fnmatch +from dataclasses import dataclass from pathlib import Path from typing import List, Optional -import datetime -from dataclasses import dataclass + import pathspec -from rich.tree import Tree -from rich.console import Console -from rich.panel import Panel -from rich.markdown import Markdown from langchain_core.tools import tool -import fnmatch +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.tree import Tree console = Console() + @dataclass class DirScanConfig: """Configuration for directory scanning""" + max_depth: int follow_links: bool show_size: bool show_modified: bool exclude_patterns: List[str] + def format_size(size_bytes: int) -> str: """Format file size in human readable format""" - for unit in ['B', 'KB', 'MB', 'GB']: + for unit in ["B", "KB", "MB", "GB"]: if size_bytes < 1024: return f"{size_bytes:.1f}{unit}" size_bytes /= 1024 return f"{size_bytes:.1f}TB" + def format_time(timestamp: float) -> str: """Format timestamp as readable date""" dt = datetime.datetime.fromtimestamp(timestamp) return dt.strftime("%Y-%m-%d %H:%M") + # Default patterns to exclude DEFAULT_EXCLUDE_PATTERNS = [ ".*", # Hidden files @@ -54,16 +60,17 @@ DEFAULT_EXCLUDE_PATTERNS = [ "*.cache", # Cache files ] + def load_gitignore_patterns(path: Path) -> pathspec.PathSpec: """Load gitignore patterns from .gitignore file or use defaults. - + Args: path: Directory path to search for .gitignore - + Returns: PathSpec object configured with the loaded patterns """ - gitignore_path = path / '.gitignore' + gitignore_path = path / ".gitignore" patterns = [] def modify_path(p: str) -> str: @@ -93,26 +100,29 @@ def load_gitignore_patterns(path: Path) -> pathspec.PathSpec: for line in f if line.strip() and not line.startswith("#") ) - + # Add default patterns patterns.extend(DEFAULT_EXCLUDE_PATTERNS) - + return pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, patterns) + def should_ignore(path: str, spec: pathspec.PathSpec) -> bool: """Check if a path should be ignored based on gitignore patterns""" return spec.match_file(path) + def should_exclude(name: str, patterns: List[str]) -> bool: """Check if a file/directory name matches any exclude patterns""" return any(fnmatch.fnmatch(name, pattern) for pattern in patterns) + def build_tree( path: Path, tree: Tree, config: DirScanConfig, current_depth: int = 0, - spec: Optional[pathspec.PathSpec] = None + spec: Optional[pathspec.PathSpec] = None, ) -> None: """Recursively build a Rich tree representation of the directory""" if current_depth >= config.max_depth: @@ -121,17 +131,17 @@ def build_tree( try: # Get sorted list of directory contents entries = sorted(path.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower())) - + for entry in entries: # Get relative path from root for pattern matching rel_path = entry.relative_to(path) - + # Skip if path matches exclude patterns if spec and should_ignore(str(rel_path), spec): continue if should_exclude(entry.name, config.exclude_patterns): continue - + # Skip if symlink and not following links if entry.is_symlink() and not config.follow_links: continue @@ -139,10 +149,8 @@ def build_tree( try: if entry.is_dir(): # Add directory node - branch = tree.add( - f"πŸ“ {entry.name}/" - ) - + branch = tree.add(f"πŸ“ {entry.name}/") + # Recursively process subdirectory build_tree(entry, branch, config, current_depth + 1, spec) else: @@ -152,19 +160,20 @@ def build_tree( meta.append(format_size(entry.stat().st_size)) if config.show_modified: meta.append(format_time(entry.stat().st_mtime)) - + label = entry.name if meta: label = f"{label} ({', '.join(meta)})" - + tree.add(label) - + except PermissionError: tree.add(f"πŸ”’ {entry.name} (Permission denied)") - + except PermissionError: tree.add("πŸ”’ (Permission denied)") + @tool def list_directory_tree( path: str = ".", @@ -173,10 +182,10 @@ def list_directory_tree( follow_links: bool = False, show_size: bool = False, # Default to not showing size show_modified: bool = False, # Default to not showing modified time - exclude_patterns: List[str] = None + exclude_patterns: List[str] = None, ) -> str: """List directory contents in a tree format with optional metadata. - + Args: path: Directory path to list max_depth: Maximum depth to traverse (default: 1 for no recursion) @@ -184,7 +193,7 @@ def list_directory_tree( show_size: Show file sizes (default: False) show_modified: Show last modified times (default: False) exclude_patterns: List of patterns to exclude (uses gitignore syntax) - + Returns: Rendered tree string """ @@ -196,7 +205,7 @@ def list_directory_tree( # Load .gitignore patterns if present spec = load_gitignore_patterns(root_path) - + # Create tree tree = Tree(f"πŸ“ {root_path}/") config = DirScanConfig( @@ -204,22 +213,24 @@ def list_directory_tree( follow_links=follow_links, show_size=show_size, show_modified=show_modified, - exclude_patterns=DEFAULT_EXCLUDE_PATTERNS + (exclude_patterns or []) + exclude_patterns=DEFAULT_EXCLUDE_PATTERNS + (exclude_patterns or []), ) - + # Build tree build_tree(root_path, tree, config, 0, spec) - + # Capture tree output with console.capture() as capture: console.print(tree) tree_str = capture.get() - + # Display panel - console.print(Panel( - Markdown(f"```\n{tree_str}\n```"), - title="πŸ“‚ Directory Tree", - border_style="bright_blue" - )) - + console.print( + Panel( + Markdown(f"```\n{tree_str}\n```"), + title="πŸ“‚ Directory Tree", + border_style="bright_blue", + ) + ) + return tree_str diff --git a/ra_aid/tools/memory.py b/ra_aid/tools/memory.py index dc1ea2f..fab91a3 100644 --- a/ra_aid/tools/memory.py +++ b/ra_aid/tools/memory.py @@ -1,10 +1,11 @@ import os -from typing import Dict, List, Any, Union, Optional, Set -from typing_extensions import TypedDict +from typing import Any, Dict, List, Optional, Set, Union + +from langchain_core.tools import tool from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel -from langchain_core.tools import tool +from typing_extensions import TypedDict class WorkLogEntry(TypedDict): @@ -249,7 +250,7 @@ def emit_key_snippets(snippets: List[SnippetInfo]) -> str: # Format display text as markdown display_text = [ - f"**Source Location**:", + "**Source Location**:", f"- File: `{snippet_info['filepath']}`", f"- Line: `{snippet_info['line_number']}`", "", # Empty line before code block diff --git a/ra_aid/tools/programmer.py b/ra_aid/tools/programmer.py index 438d27f..a9e6b8a 100644 --- a/ra_aid/tools/programmer.py +++ b/ra_aid/tools/programmer.py @@ -1,42 +1,51 @@ import os -from typing import List, Dict, Union -from ra_aid.logging_config import get_logger -from ra_aid.tools.memory import _global_memory +from typing import Dict, List, Union + from langchain_core.tools import tool from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel from rich.text import Text + +from ra_aid.logging_config import get_logger from ra_aid.proc.interactive import run_interactive_command from ra_aid.text.processing import truncate_output +from ra_aid.tools.memory import _global_memory console = Console() logger = get_logger(__name__) + @tool -def run_programming_task(instructions: str, files: List[str] = []) -> Dict[str, Union[str, int, bool]]: +def run_programming_task( + instructions: str, files: List[str] = [] +) -> Dict[str, Union[str, int, bool]]: """Assign a programming task to a human programmer. Use this instead of trying to write code to files yourself. -Before using this tool, ensure all related files have been emitted with emit_related_files. + Before using this tool, ensure all related files have been emitted with emit_related_files. -The programmer sees only what you provide, no conversation history. + The programmer sees only what you provide, no conversation history. -Give detailed instructions but do not write their code. + Give detailed instructions but do not write their code. -They are intelligent and can edit multiple files. + They are intelligent and can edit multiple files. -If new files are created, emit them after finishing. + If new files are created, emit them after finishing. -They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish. + They can add/modify files, but not remove. Use run_shell_command to remove files. If referencing files you’ll delete, remove them after they finish. -Args: - instructions: REQUIRED Programming task instructions (markdown format, use newlines and as many tokens as needed) - files: Optional; if not provided, uses related_files + Args: + instructions: REQUIRED Programming task instructions (markdown format, use newlines and as many tokens as needed) + files: Optional; if not provided, uses related_files -Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True/False } + Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True/False } """ # Get related files if no specific files provided - file_paths = list(_global_memory['related_files'].values()) if 'related_files' in _global_memory else [] + file_paths = ( + list(_global_memory["related_files"].values()) + if "related_files" in _global_memory + else [] + ) # Build command command = [ @@ -50,14 +59,14 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True ] # Add config file if specified - if 'config' in _global_memory and _global_memory['config'].get('aider_config'): - command.extend(['--config', _global_memory['config']['aider_config']]) + if "config" in _global_memory and _global_memory["config"].get("aider_config"): + command.extend(["--config", _global_memory["config"]["aider_config"]]) # if environment variable AIDER_FLAGS exists then parse - if 'AIDER_FLAGS' in os.environ: + if "AIDER_FLAGS" in os.environ: # wrap in try catch in case of any error and log the error try: - command.extend(parse_aider_flags(os.environ['AIDER_FLAGS'])) + command.extend(parse_aider_flags(os.environ["AIDER_FLAGS"])) except Exception as e: print(f"Error parsing AIDER_FLAGS: {e}") @@ -72,19 +81,21 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True command.extend(files_to_use) # Create a pretty display of what we're doing - task_display = [ - "## Instructions\n", - f"{instructions}\n" - ] + task_display = ["## Instructions\n", f"{instructions}\n"] if files_to_use: - task_display.extend([ - "\n## Files\n", - *[f"- `{file}`\n" for file in files_to_use] - ]) + task_display.extend( + ["\n## Files\n", *[f"- `{file}`\n" for file in files_to_use]] + ) markdown_content = "".join(task_display) - console.print(Panel(Markdown(markdown_content), title="πŸ€– Aider Task", border_style="bright_blue")) + console.print( + Panel( + Markdown(markdown_content), + title="πŸ€– Aider Task", + border_style="bright_blue", + ) + ) logger.debug(f"command: {command}") try: @@ -97,7 +108,7 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True return { "output": truncate_output(output.decode() if output else ""), "return_code": return_code, - "success": return_code == 0 + "success": return_code == 0, } except Exception as e: @@ -107,11 +118,8 @@ Returns: { "output": stdout+stderr, "return_code": 0 if success, "success": True error_text.append(str(e), style="red") console.print(error_text) - return { - "output": str(e), - "return_code": 1, - "success": False - } + return {"output": str(e), "return_code": 1, "success": False} + def parse_aider_flags(aider_flags: str) -> List[str]: """Parse a string of aider flags into a list of flags. @@ -140,5 +148,6 @@ def parse_aider_flags(aider_flags: str) -> List[str]: # Add '--' prefix if not present and filter out empty flags return [f"--{flag.lstrip('-')}" for flag in flags if flag.strip()] + # Export the functions -__all__ = ['run_programming_task'] +__all__ = ["run_programming_task"] diff --git a/ra_aid/tools/read_file.py b/ra_aid/tools/read_file.py index 420e87e..564bc79 100644 --- a/ra_aid/tools/read_file.py +++ b/ra_aid/tools/read_file.py @@ -1,10 +1,12 @@ -import os.path import logging +import os.path import time from typing import Dict + from langchain_core.tools import tool from rich.console import Console from rich.panel import Panel + from ra_aid.text.processing import truncate_output console = Console() @@ -12,11 +14,9 @@ console = Console() # Standard buffer size for file reading CHUNK_SIZE = 8192 + @tool -def read_file_tool( - filepath: str, - encoding: str = 'utf-8' -) -> Dict[str, str]: +def read_file_tool(filepath: str, encoding: str = "utf-8") -> Dict[str, str]: """Read and return the contents of a text file. Args: @@ -33,35 +33,39 @@ def read_file_tool( line_count = 0 total_bytes = 0 - with open(filepath, 'r', encoding=encoding) as f: + with open(filepath, "r", encoding=encoding) as f: while True: chunk = f.read(CHUNK_SIZE) if not chunk: break - + content.append(chunk) total_bytes += len(chunk) - line_count += chunk.count('\n') - - logging.debug(f"Read chunk: {len(chunk)} bytes, running total: {total_bytes} bytes") + line_count += chunk.count("\n") - full_content = ''.join(content) + logging.debug( + f"Read chunk: {len(chunk)} bytes, running total: {total_bytes} bytes" + ) + + full_content = "".join(content) elapsed = time.time() - start_time - + logging.debug(f"File read complete: {total_bytes} bytes in {elapsed:.2f}s") logging.debug(f"Pre-truncation stats: {total_bytes} bytes, {line_count} lines") - console.print(Panel( - f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s", - title="πŸ“„ File Read", - border_style="bright_blue" - )) - + console.print( + Panel( + f"Read {line_count} lines ({total_bytes} bytes) from {filepath} in {elapsed:.2f}s", + title="πŸ“„ File Read", + border_style="bright_blue", + ) + ) + # Truncate if needed truncated = truncate_output(full_content) if full_content else "" return {"content": truncated} - except Exception as e: + except Exception: elapsed = time.time() - start_time raise diff --git a/ra_aid/tools/reflection.py b/ra_aid/tools/reflection.py index 2059629..9fb68d7 100644 --- a/ra_aid/tools/reflection.py +++ b/ra_aid/tools/reflection.py @@ -7,7 +7,7 @@ This module provides utilities for: import inspect -__all__ = ['get_function_info'] +__all__ = ["get_function_info"] def get_function_info(func): @@ -32,5 +32,3 @@ def get_function_info(func): {docstring} \"\"\"""" return info - - diff --git a/ra_aid/tools/research.py b/ra_aid/tools/research.py index 12b96a9..888f55f 100644 --- a/ra_aid/tools/research.py +++ b/ra_aid/tools/research.py @@ -4,6 +4,7 @@ from rich.panel import Panel console = Console() + @tool("existing_project_detected") def existing_project_detected() -> dict: """ @@ -11,7 +12,7 @@ def existing_project_detected() -> dict: """ console.print(Panel("πŸ“ Existing Project Detected", style="bright_blue", padding=0)) return { - 'hint': ( + "hint": ( "You are working within an existing codebase that already has established patterns and standards. " "Integrate any new functionality by adhering to the project's conventions:\n\n" "- Carefully discover existing folder structure, naming conventions, and architecture.\n" @@ -23,6 +24,7 @@ def existing_project_detected() -> dict: ) } + @tool("monorepo_detected") def monorepo_detected() -> dict: """ @@ -30,7 +32,7 @@ def monorepo_detected() -> dict: """ console.print(Panel("πŸ“¦ Monorepo Detected", style="bright_blue", padding=0)) return { - 'hint': ( + "hint": ( "You are researching in a monorepo environment that manages multiple packages or services under one roof. " "Ensure new work fits cohesively within the broader structure:\n\n" "- Search all packages for shared libraries, utilities, and patterns, and reuse them to avoid redundancy.\n" @@ -45,6 +47,7 @@ def monorepo_detected() -> dict: ) } + @tool("ui_detected") def ui_detected() -> dict: """ @@ -52,7 +55,7 @@ def ui_detected() -> dict: """ console.print(Panel("🎯 UI Detected", style="bright_blue", padding=0)) return { - 'hint': ( + "hint": ( "You are working with a user interface component where established UI conventions, styles, and frameworks are likely in place. " "Any modifications or additions should blend seamlessly with the existing design and user experience:\n\n" "- Locate and note existing UI design conventions, including layout, typography, color schemes, and interaction patterns.\n" diff --git a/ra_aid/tools/ripgrep.py b/ra_aid/tools/ripgrep.py index cfa293f..1722290 100644 --- a/ra_aid/tools/ripgrep.py +++ b/ra_aid/tools/ripgrep.py @@ -1,27 +1,29 @@ -from typing import Dict, Union, List +from typing import Dict, List, Union + from langchain_core.tools import tool from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel + from ra_aid.proc.interactive import run_interactive_command from ra_aid.text.processing import truncate_output console = Console() DEFAULT_EXCLUDE_DIRS = [ - '.git', - 'node_modules', - 'vendor', - '.venv', - '__pycache__', - '.cache', - 'dist', - 'build', - 'env', - '.env', - 'venv', - '.idea', - '.vscode' + ".git", + "node_modules", + "vendor", + ".venv", + "__pycache__", + ".cache", + "dist", + "build", + "env", + ".env", + "venv", + ".idea", + ".vscode", ] @@ -64,6 +66,7 @@ FILE_TYPE_MAP = { "psql": "postgres", } + @tool def ripgrep_search( pattern: str, @@ -72,7 +75,7 @@ def ripgrep_search( case_sensitive: bool = True, include_hidden: bool = False, follow_links: bool = False, - exclude_dirs: List[str] = None + exclude_dirs: List[str] = None, ) -> Dict[str, Union[str, int, bool]]: """Execute a ripgrep (rg) search with formatting and common options. @@ -91,17 +94,17 @@ def ripgrep_search( - success: Boolean indicating if search succeeded """ # Build rg command with options - cmd = ['rg', '--color', 'always'] - + cmd = ["rg", "--color", "always"] + if not case_sensitive: - cmd.append('-i') - + cmd.append("-i") + if include_hidden: - cmd.append('--hidden') - + cmd.append("--hidden") + if follow_links: - cmd.append('--follow') - + cmd.append("--follow") + if file_type: if FILE_TYPE_MAP.get(file_type): file_type = FILE_TYPE_MAP.get(file_type) @@ -110,20 +113,20 @@ def ripgrep_search( # Add exclusions exclusions = DEFAULT_EXCLUDE_DIRS + (exclude_dirs or []) for dir in exclusions: - cmd.extend(['--glob', f'!{dir}']) + cmd.extend(["--glob", f"!{dir}"]) # Add the search pattern cmd.append(pattern) # Build info sections for display info_sections = [] - + # Search parameters section params = [ "## Search Parameters", f"**Pattern**: `{pattern}`", f"**Case Sensitive**: {case_sensitive}", - f"**File Type**: {file_type or 'all'}" + f"**File Type**: {file_type or 'all'}", ] if include_hidden: params.append("**Including Hidden Files**: yes") @@ -136,24 +139,26 @@ def ripgrep_search( info_sections.append("\n".join(params)) # Execute command - console.print(Panel(Markdown(f"Searching for: **{pattern}**"), title="πŸ”Ž Ripgrep Search", border_style="bright_blue")) + console.print( + Panel( + Markdown(f"Searching for: **{pattern}**"), + title="πŸ”Ž Ripgrep Search", + border_style="bright_blue", + ) + ) try: print() output, return_code = run_interactive_command(cmd) print() decoded_output = output.decode() if output else "" - + return { "output": truncate_output(decoded_output), "return_code": return_code, - "success": return_code == 0 + "success": return_code == 0, } - + except Exception as e: error_msg = str(e) console.print(Panel(error_msg, title="❌ Error", border_style="red")) - return { - "output": error_msg, - "return_code": 1, - "success": False - } + return {"output": error_msg, "return_code": 1, "success": False} diff --git a/ra_aid/tools/shell.py b/ra_aid/tools/shell.py index fcc7114..271128a 100644 --- a/ra_aid/tools/shell.py +++ b/ra_aid/tools/shell.py @@ -1,22 +1,25 @@ from typing import Dict, Union + from langchain_core.tools import tool from rich.console import Console from rich.panel import Panel from rich.prompt import Prompt -from ra_aid.tools.memory import _global_memory + +from ra_aid.console.cowboy_messages import get_cowboy_message from ra_aid.proc.interactive import run_interactive_command from ra_aid.text.processing import truncate_output -from ra_aid.console.cowboy_messages import get_cowboy_message +from ra_aid.tools.memory import _global_memory console = Console() + @tool def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]: """Execute a shell command and return its output. Important notes: 1. Try to constrain/limit the output. Output processing is expensive, and infinite/looping output will cause us to fail. - 2. When using commands like 'find', 'grep', or similar recursive search tools, always exclude common + 2. When using commands like 'find', 'grep', or similar recursive search tools, always exclude common development directories and files that can cause excessive output or slow performance: - Version control: .git - Dependencies: node_modules, vendor, .venv @@ -28,8 +31,8 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]: 4. Add flags e.g. git --no-pager in order to reduce interaction required by the human. """ # Check if we need approval - cowboy_mode = _global_memory.get('config', {}).get('cowboy_mode', False) - + cowboy_mode = _global_memory.get("config", {}).get("cowboy_mode", False) + if cowboy_mode: console.print("") console.print(" " + get_cowboy_message()) @@ -37,7 +40,7 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]: # Show just the command in a simple panel console.print(Panel(command, title="🐚 Shell", border_style="bright_yellow")) - + if not cowboy_mode: choices = ["y", "n", "c"] response = Prompt.ask( @@ -45,36 +48,32 @@ def run_shell_command(command: str) -> Dict[str, Union[str, int, bool]]: choices=choices, default="y", show_choices=True, - show_default=True + show_default=True, ) - + if response == "n": print() return { "output": "Command execution cancelled by user", "return_code": 1, - "success": False + "success": False, } elif response == "c": - _global_memory['config']['cowboy_mode'] = True + _global_memory["config"]["cowboy_mode"] = True console.print("") console.print(" " + get_cowboy_message()) console.print("") - + try: print() - output, return_code = run_interactive_command(['/bin/bash', '-c', command]) + output, return_code = run_interactive_command(["/bin/bash", "-c", command]) print() return { "output": truncate_output(output.decode()) if output else "", "return_code": return_code, - "success": return_code == 0 + "success": return_code == 0, } except Exception as e: print() console.print(Panel(str(e), title="❌ Error", border_style="red")) - return { - "output": str(e), - "return_code": 1, - "success": False - } + return {"output": str(e), "return_code": 1, "success": False} diff --git a/ra_aid/tools/web_search_tavily.py b/ra_aid/tools/web_search_tavily.py index e7cc8b7..897df45 100644 --- a/ra_aid/tools/web_search_tavily.py +++ b/ra_aid/tools/web_search_tavily.py @@ -1,25 +1,29 @@ import os from typing import Dict -from tavily import TavilyClient + from langchain_core.tools import tool from rich.console import Console -from rich.panel import Panel from rich.markdown import Markdown +from rich.panel import Panel +from tavily import TavilyClient console = Console() + @tool def web_search_tavily(query: str) -> Dict: """ Perform a web search using Tavily API. - + Args: query: The search query string - + Returns: Dict containing search results from Tavily """ client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"]) - console.print(Panel(Markdown(query), title="πŸ” Searching Tavily", border_style="bright_blue")) + console.print( + Panel(Markdown(query), title="πŸ” Searching Tavily", border_style="bright_blue") + ) search_result = client.search(query=query) return search_result diff --git a/ra_aid/tools/write_file.py b/ra_aid/tools/write_file.py index 29a6452..bffff18 100644 --- a/ra_aid/tools/write_file.py +++ b/ra_aid/tools/write_file.py @@ -1,19 +1,18 @@ -import os import logging +import os import time from typing import Dict + from langchain_core.tools import tool from rich.console import Console from rich.panel import Panel console = Console() + @tool def write_file_tool( - filepath: str, - content: str, - encoding: str = 'utf-8', - verbose: bool = True + filepath: str, content: str, encoding: str = "utf-8", verbose: bool = True ) -> Dict[str, any]: """Write content to a text file. @@ -40,7 +39,7 @@ def write_file_tool( "elapsed_time": 0, "error": None, "filepath": None, - "message": None + "message": None, } try: @@ -50,42 +49,48 @@ def write_file_tool( os.makedirs(dirpath, exist_ok=True) logging.debug(f"Starting to write file: {filepath}") - - with open(filepath, 'w', encoding=encoding) as f: + + with open(filepath, "w", encoding=encoding) as f: f.write(content) result["bytes_written"] = len(content.encode(encoding)) - + elapsed = time.time() - start_time result["elapsed_time"] = elapsed result["success"] = True result["filepath"] = filepath result["message"] = "Operation completed successfully" - logging.debug(f"File write complete: {result['bytes_written']} bytes in {elapsed:.2f}s") + logging.debug( + f"File write complete: {result['bytes_written']} bytes in {elapsed:.2f}s" + ) if verbose: - console.print(Panel( - f"Wrote {result['bytes_written']} bytes to {filepath} in {elapsed:.2f}s", - title="πŸ’Ύ File Write", - border_style="bright_green" - )) + console.print( + Panel( + f"Wrote {result['bytes_written']} bytes to {filepath} in {elapsed:.2f}s", + title="πŸ’Ύ File Write", + border_style="bright_green", + ) + ) except Exception as e: elapsed = time.time() - start_time error_msg = str(e) - + result["elapsed_time"] = elapsed result["error"] = error_msg if "embedded null byte" in error_msg.lower(): result["message"] = "Invalid file path: contains null byte character" else: result["message"] = error_msg - + if verbose: - console.print(Panel( - f"Failed to write {filepath}\nError: {error_msg}", - title="❌ File Write Error", - border_style="red" - )) + console.print( + Panel( + f"Failed to write {filepath}\nError: {error_msg}", + title="❌ File Write Error", + border_style="red", + ) + ) return result diff --git a/scripts/extract_changelog.py b/scripts/extract_changelog.py index 14dbb84..62d0952 100755 --- a/scripts/extract_changelog.py +++ b/scripts/extract_changelog.py @@ -6,8 +6,8 @@ Usage: python extract_changelog.py VERSION """ -import sys import re +import sys from pathlib import Path @@ -16,11 +16,11 @@ def extract_version_content(content: str, version: str) -> str: # Escape version for regex pattern version_escaped = re.escape(version) pattern = rf"## \[{version_escaped}\].*?(?=## \[|$)" - + match = re.search(pattern, content, re.DOTALL) if not match: raise ValueError(f"Version {version} not found in changelog") - + return match.group(0).strip() diff --git a/scripts/generate_swebench_dataset.py b/scripts/generate_swebench_dataset.py index 5bd0caf..4ba32d5 100755 --- a/scripts/generate_swebench_dataset.py +++ b/scripts/generate_swebench_dataset.py @@ -22,7 +22,7 @@ import subprocess import sys from datetime import datetime from pathlib import Path -from typing import Optional, Tuple, Dict, Any, List +from typing import Any, Dict, List, Optional, Tuple from git import Repo from rich.logging import RichHandler @@ -44,15 +44,13 @@ def setup_logging(log_dir: Path, verbose: bool = False) -> None: file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.DEBUG) file_formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) file_handler.setFormatter(file_formatter) root_logger.addHandler(file_handler) console_handler = RichHandler( - rich_tracebacks=True, - show_time=False, - show_path=False + rich_tracebacks=True, show_time=False, show_path=False ) console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) root_logger.addHandler(console_handler) @@ -62,6 +60,7 @@ def load_dataset_safely() -> Optional[Any]: """Load SWE-bench Lite dataset with error handling.""" try: from datasets import load_dataset + dataset = load_dataset("princeton-nlp/SWE-bench_Lite") return dataset except Exception as e: @@ -122,18 +121,14 @@ def uv_run_raaid(repo_dir: Path, prompt: str) -> Optional[str]: streaming output directly to the console (capture_output=False). Returns the patch if successful, else None. """ - cmd = [ - "uv", "run", "ra-aid", - "--cowboy-mode", - "-m", prompt - ] + cmd = ["uv", "run", "ra-aid", "--cowboy-mode", "-m", prompt] # We are NOT capturing output, so it streams live: try: result = subprocess.run( cmd, cwd=repo_dir, text=True, - check=False, # We manually handle exit code + check=False, # We manually handle exit code ) if result.returncode != 0: logging.error("ra-aid returned non-zero exit code.") @@ -160,7 +155,7 @@ def get_git_patch(repo_dir: Path) -> Optional[str]: patch_text = repo.git.diff(unified=3) if not patch_text.strip(): return None - if not any(line.startswith('+') for line in patch_text.splitlines()): + if not any(line.startswith("+") for line in patch_text.splitlines()): return None return patch_text except Exception as e: @@ -214,7 +209,9 @@ def setup_venv_and_deps(repo_dir: Path, repo_name: str, force_venv: bool) -> Non uv_pip_install(repo_dir, ["-e", "."]) -def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List[str]) -> str: +def build_prompt( + problem_statement: str, fail_tests: List[str], pass_tests: List[str] +) -> str: """ Construct the prompt text from problem_statement, FAIL_TO_PASS, PASS_TO_PASS. """ @@ -232,10 +229,7 @@ def build_prompt(problem_statement: str, fail_tests: List[str], pass_tests: List def process_instance( - instance: Dict[str, Any], - projects_dir: Path, - reuse_repo: bool, - force_venv: bool + instance: Dict[str, Any], projects_dir: Path, reuse_repo: bool, force_venv: bool ) -> Dict[str, Any]: """ Process a single dataset instance without a progress bar/spinner. @@ -291,7 +285,7 @@ def process_instance( return { "instance_id": inst_id, "model_patch": patch if patch else "", - "model_name_or_path": "ra-aid" + "model_name_or_path": "ra-aid", } except Exception as e: @@ -299,7 +293,7 @@ def process_instance( return { "instance_id": inst_id, "model_patch": "", - "model_name_or_path": "ra-aid" + "model_name_or_path": "ra-aid", } @@ -308,41 +302,33 @@ def main() -> None: description="Generate predictions for SWE-bench Lite using uv + ra-aid (no progress bar)." ) parser.add_argument( - "output_dir", - type=Path, - help="Directory to store prediction file" + "output_dir", type=Path, help="Directory to store prediction file" ) parser.add_argument( "--projects-dir", type=Path, required=True, - help="Directory where projects will be cloned." + help="Directory where projects will be cloned.", ) parser.add_argument( "--num-instances", type=int, default=None, - help="Number of instances to process (default: all)" + help="Number of instances to process (default: all)", ) parser.add_argument( "--reuse-repo", action="store_true", - help="If set, do not delete an existing repo directory. We'll reuse it." + help="If set, do not delete an existing repo directory. We'll reuse it.", ) parser.add_argument( "--force-venv", action="store_true", - help="If set, recreate the .venv even if it exists." - ) - parser.add_argument( - "--verbose", - action="store_true", - help="Enable verbose logging" + help="If set, recreate the .venv even if it exists.", ) + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") args = parser.parse_args() - from datasets import load_dataset - # Create base/log dirs and set up logging base_dir, log_dir = create_output_dirs() setup_logging(log_dir, args.verbose) @@ -373,7 +359,9 @@ def main() -> None: break logging.info(f"=== Instance {i+1}/{limit}, ID={inst.get('instance_id')} ===") - pred = process_instance(inst, args.projects_dir, args.reuse_repo, args.force_venv) + pred = process_instance( + inst, args.projects_dir, args.reuse_repo, args.force_venv + ) predictions.append(pred) # Save predictions @@ -389,6 +377,6 @@ if __name__ == "__main__": except KeyboardInterrupt: print("\nOperation cancelled by user.") sys.exit(1) - except Exception as e: + except Exception: logging.exception("Unhandled error occurred.") - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/tests/ra_aid/agents/test_ciayn_agent.py b/tests/ra_aid/agents/test_ciayn_agent.py index 5803412..87c80ee 100644 --- a/tests/ra_aid/agents/test_ciayn_agent.py +++ b/tests/ra_aid/agents/test_ciayn_agent.py @@ -1,8 +1,10 @@ +from unittest.mock import Mock + import pytest -from unittest.mock import Mock, patch -from langchain_core.messages import HumanMessage, AIMessage -from ra_aid.agents.ciayn_agent import CiaynAgent -from ra_aid.agents.ciayn_agent import validate_function_call_pattern +from langchain_core.messages import AIMessage, HumanMessage + +from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern + @pytest.fixture def mock_model(): @@ -11,47 +13,48 @@ def mock_model(): model.invoke = Mock() return model + @pytest.fixture def agent(mock_model): """Create a CiaynAgent instance with mock model.""" tools = [] # Empty tools list for testing trimming functionality return CiaynAgent(mock_model, tools, max_history_messages=3) + def test_trim_chat_history_preserves_initial_messages(agent): """Test that initial messages are preserved during trimming.""" initial_messages = [ HumanMessage(content="Initial 1"), - AIMessage(content="Initial 2") + AIMessage(content="Initial 2"), ] chat_history = [ HumanMessage(content="Chat 1"), AIMessage(content="Chat 2"), HumanMessage(content="Chat 3"), - AIMessage(content="Chat 4") + AIMessage(content="Chat 4"), ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Verify initial messages are preserved assert result[:2] == initial_messages # Verify only last 3 chat messages are kept (due to max_history_messages=3) assert len(result[2:]) == 3 assert result[2:] == chat_history[-3:] + def test_trim_chat_history_under_limit(agent): """Test trimming when chat history is under the maximum limit.""" initial_messages = [HumanMessage(content="Initial")] - chat_history = [ - HumanMessage(content="Chat 1"), - AIMessage(content="Chat 2") - ] - + chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")] + result = agent._trim_chat_history(initial_messages, chat_history) - + # Verify no trimming occurred assert len(result) == 3 assert result == initial_messages + chat_history + def test_trim_chat_history_over_limit(agent): """Test trimming when chat history exceeds the maximum limit.""" initial_messages = [HumanMessage(content="Initial")] @@ -60,16 +63,17 @@ def test_trim_chat_history_over_limit(agent): AIMessage(content="Chat 2"), HumanMessage(content="Chat 3"), AIMessage(content="Chat 4"), - HumanMessage(content="Chat 5") + HumanMessage(content="Chat 5"), ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Verify correct trimming assert len(result) == 4 # initial + max_history_messages assert result[0] == initial_messages[0] # Initial message preserved assert result[1:] == chat_history[-3:] # Last 3 chat messages kept + def test_trim_chat_history_empty_initial(agent): """Test trimming with empty initial messages.""" initial_messages = [] @@ -77,79 +81,83 @@ def test_trim_chat_history_empty_initial(agent): HumanMessage(content="Chat 1"), AIMessage(content="Chat 2"), HumanMessage(content="Chat 3"), - AIMessage(content="Chat 4") + AIMessage(content="Chat 4"), ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Verify only last 3 messages are kept assert len(result) == 3 assert result == chat_history[-3:] + def test_trim_chat_history_empty_chat(agent): """Test trimming with empty chat history.""" initial_messages = [ HumanMessage(content="Initial 1"), - AIMessage(content="Initial 2") + AIMessage(content="Initial 2"), ] chat_history = [] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Verify initial messages are preserved and no trimming occurred assert result == initial_messages assert len(result) == 2 + def test_trim_chat_history_token_limit(): """Test trimming based on token limit.""" agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=20) - - initial_messages = [HumanMessage(content="Initial")] # ~2 tokens + + initial_messages = [HumanMessage(content="Initial")] # ~2 tokens chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens - AIMessage(content="B" * 40), # ~10 tokens - HumanMessage(content="C" * 40) # ~10 tokens + AIMessage(content="B" * 40), # ~10 tokens + HumanMessage(content="C" * 40), # ~10 tokens ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Should keep initial message (~2 tokens) and last message (~10 tokens) assert len(result) == 2 assert result[0] == initial_messages[0] assert result[1] == chat_history[-1] + def test_trim_chat_history_no_token_limit(): """Test trimming with no token limit set.""" agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None) - + initial_messages = [HumanMessage(content="Initial")] chat_history = [ HumanMessage(content="A" * 1000), AIMessage(content="B" * 1000), - HumanMessage(content="C" * 1000) + HumanMessage(content="C" * 1000), ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Should keep initial message and last 2 messages (max_history_messages=2) assert len(result) == 3 assert result[0] == initial_messages[0] assert result[1:] == chat_history[-2:] + def test_trim_chat_history_both_limits(): """Test trimming with both message count and token limits.""" agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=15) - - initial_messages = [HumanMessage(content="Init")] # ~1 token + + initial_messages = [HumanMessage(content="Init")] # ~1 token chat_history = [ HumanMessage(content="A" * 40), # ~10 tokens - AIMessage(content="B" * 40), # ~10 tokens + AIMessage(content="B" * 40), # ~10 tokens HumanMessage(content="C" * 40), # ~10 tokens - AIMessage(content="D" * 40) # ~10 tokens + AIMessage(content="D" * 40), # ~10 tokens ] - + result = agent._trim_chat_history(initial_messages, chat_history) - + # Should first apply message limit (keeping last 3) # Then token limit should further reduce to fit under 15 tokens assert len(result) == 2 # Initial message + 1 message under token limit @@ -158,46 +166,58 @@ def test_trim_chat_history_both_limits(): class TestFunctionCallValidation: - @pytest.mark.parametrize("test_input", [ - "basic_func()", - "func_with_arg(\"test\")", - "complex_func(1, \"two\", three)", - "nested_parens(func(\"test\"))", - "under_score()", - "with-dash()" - ]) + @pytest.mark.parametrize( + "test_input", + [ + "basic_func()", + 'func_with_arg("test")', + 'complex_func(1, "two", three)', + 'nested_parens(func("test"))', + "under_score()", + "with-dash()", + ], + ) def test_valid_function_calls(self, test_input): """Test function call patterns that should pass validation.""" assert not validate_function_call_pattern(test_input) - @pytest.mark.parametrize("test_input", [ - "", - "Invalid!function()", - "missing_parens", - "unmatched(parens))", - "multiple()calls()", - "no spaces()()" - ]) + @pytest.mark.parametrize( + "test_input", + [ + "", + "Invalid!function()", + "missing_parens", + "unmatched(parens))", + "multiple()calls()", + "no spaces()()", + ], + ) def test_invalid_function_calls(self, test_input): """Test function call patterns that should fail validation.""" assert validate_function_call_pattern(test_input) - @pytest.mark.parametrize("test_input", [ - " leading_space()", - "trailing_space() ", - "func (arg)", - "func( spaced args )" - ]) + @pytest.mark.parametrize( + "test_input", + [ + " leading_space()", + "trailing_space() ", + "func (arg)", + "func( spaced args )", + ], + ) def test_whitespace_handling(self, test_input): """Test whitespace variations in function calls.""" assert not validate_function_call_pattern(test_input) - @pytest.mark.parametrize("test_input", [ - """multiline( + @pytest.mark.parametrize( + "test_input", + [ + """multiline( arg )""", - "func(\n arg1,\n arg2\n)" - ]) + "func(\n arg1,\n arg2\n)", + ], + ) def test_multiline_responses(self, test_input): """Test function calls spanning multiple lines.""" assert not validate_function_call_pattern(test_input) diff --git a/tests/ra_aid/console/test_cowboy_messages.py b/tests/ra_aid/console/test_cowboy_messages.py index f023068..36db3ba 100644 --- a/tests/ra_aid/console/test_cowboy_messages.py +++ b/tests/ra_aid/console/test_cowboy_messages.py @@ -1,5 +1,5 @@ -import pytest -from ra_aid.console.cowboy_messages import get_cowboy_message, COWBOY_MESSAGES +from ra_aid.console.cowboy_messages import COWBOY_MESSAGES, get_cowboy_message + def test_get_cowboy_message_returns_string(): """Test that get_cowboy_message returns a non-empty string""" @@ -7,12 +7,14 @@ def test_get_cowboy_message_returns_string(): assert isinstance(message, str) assert len(message) > 0 + def test_cowboy_message_contains_emoji(): """Test that returned message contains at least one of the expected emojis""" message = get_cowboy_message() - expected_emojis = ['🀠', 'πŸ‘Ά', '😏'] + expected_emojis = ["🀠", "πŸ‘Ά", "😏"] assert any(emoji in message for emoji in expected_emojis) + def test_message_from_predefined_list(): """Test that returned message is from our predefined list""" message = get_cowboy_message() diff --git a/tests/ra_aid/proc/test_interactive.py b/tests/ra_aid/proc/test_interactive.py index 687b45c..a87e7a5 100644 --- a/tests/ra_aid/proc/test_interactive.py +++ b/tests/ra_aid/proc/test_interactive.py @@ -1,9 +1,10 @@ """Tests for the interactive subprocess module.""" import os -import sys -import pytest import tempfile + +import pytest + from ra_aid.proc.interactive import run_interactive_command @@ -16,7 +17,9 @@ def test_basic_command(): def test_shell_pipeline(): """Test running a shell pipeline command.""" - output, retcode = run_interactive_command(["/bin/bash", "-c", "echo 'hello world' | grep 'world'"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", "echo 'hello world' | grep 'world'"] + ) assert b"world" in output assert retcode == 0 @@ -24,7 +27,9 @@ def test_shell_pipeline(): def test_stderr_capture(): """Test that stderr is properly captured in combined output.""" # Use a command that definitely writes to stderr - output, retcode = run_interactive_command(["/bin/bash", "-c", "ls /nonexistent/path"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", "ls /nonexistent/path"] + ) assert b"No such file or directory" in output assert retcode != 0 # ls returns 0 upon success @@ -43,25 +48,32 @@ def test_empty_command(): def test_interactive_command(): """Test running an interactive command. - + This test verifies that output appears in real-time using process substitution. We use a command that prints to both stdout and stderr to verify capture.""" - output, retcode = run_interactive_command(["/bin/bash", "-c", "echo stdout; echo stderr >&2"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", "echo stdout; echo stderr >&2"] + ) assert b"stdout" in output assert b"stderr" in output assert retcode == 0 + def test_large_output(): """Test handling of commands that produce large output.""" # Generate a large output with predictable content - cmd = "for i in {1..10000}; do echo \"Line $i of test output\"; done" + cmd = 'for i in {1..10000}; do echo "Line $i of test output"; done' output, retcode = run_interactive_command(["/bin/bash", "-c", cmd]) # Clean up specific artifacts (e.g., ^D) - output_cleaned = output.lstrip(b'^D') # Remove the leading ^D if present + output_cleaned = output.lstrip(b"^D") # Remove the leading ^D if present # Split and filter lines - lines = [line.strip() for line in output_cleaned.splitlines() if b"Script" not in line and line.strip()] + lines = [ + line.strip() + for line in output_cleaned.splitlines() + if b"Script" not in line and line.strip() + ] # Verify we got all 10000 lines assert len(lines) == 10000, f"Expected 10000 lines, but got {len(lines)}" @@ -78,14 +90,18 @@ def test_large_output(): def test_unicode_handling(): """Test handling of unicode characters.""" test_string = "Hello " - output, retcode = run_interactive_command(["/bin/bash", "-c", f"echo '{test_string}'"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", f"echo '{test_string}'"] + ) assert test_string.encode() in output assert retcode == 0 def test_multiple_commands(): """Test running multiple commands in sequence.""" - output, retcode = run_interactive_command(["/bin/bash", "-c", "echo 'first'; echo 'second'"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", "echo 'first'; echo 'second'"] + ) assert b"first" in output assert b"second" in output assert retcode == 0 @@ -94,18 +110,24 @@ def test_multiple_commands(): def test_cat_medium_file(): """Test that cat command properly captures output for medium-length files.""" # Create a temporary file with known content - with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: for i in range(500): f.write(f"This is test line {i}\n") temp_path = f.name try: - output, retcode = run_interactive_command(["/bin/bash", "-c", f"cat {temp_path}"]) + output, retcode = run_interactive_command( + ["/bin/bash", "-c", f"cat {temp_path}"] + ) # Split by newlines and filter out script header/footer lines - lines = [line for line in output.splitlines() if b"Script" not in line and line.strip()] + lines = [ + line + for line in output.splitlines() + if b"Script" not in line and line.strip() + ] assert len(lines) == 500 assert retcode == 0 - + # Verify content integrity by checking first and last lines assert b"This is test line 0" in lines[0] assert b"This is test line 499" in lines[-1] @@ -118,26 +140,31 @@ def test_realtime_output(): # Create a command that sleeps briefly between outputs cmd = "echo 'first'; sleep 0.1; echo 'second'; sleep 0.1; echo 'third'" output, retcode = run_interactive_command(["/bin/bash", "-c", cmd]) - + # Filter out script header/footer lines - lines = [line for line in output.splitlines() if b"Script" not in line and line.strip()] - + lines = [ + line for line in output.splitlines() if b"Script" not in line and line.strip() + ] + assert b"first" in lines[0] assert b"second" in lines[1] assert b"third" in lines[2] assert retcode == 0 + def test_tty_available(): """Test that commands have access to a TTY.""" # Run the tty command output, retcode = run_interactive_command(["/bin/bash", "-c", "tty"]) # Clean up specific artifacts (e.g., ^D) - output_cleaned = output.lstrip(b'^D') # Remove leading ^D if present + output_cleaned = output.lstrip(b"^D") # Remove leading ^D if present # Debug: Print cleaned output print(f"Cleaned TTY Output: {output_cleaned}") # Check if the output contains a valid TTY path - assert b"/dev/pts/" in output_cleaned or b"/dev/ttys" in output_cleaned, f"Unexpected TTY output: {output_cleaned}" + assert ( + b"/dev/pts/" in output_cleaned or b"/dev/ttys" in output_cleaned + ), f"Unexpected TTY output: {output_cleaned}" assert retcode == 0 diff --git a/tests/ra_aid/test_agent_utils.py b/tests/ra_aid/test_agent_utils.py index fa7cfc2..35e59d1 100644 --- a/tests/ra_aid/test_agent_utils.py +++ b/tests/ra_aid/test_agent_utils.py @@ -1,16 +1,19 @@ """Unit tests for agent_utils.py.""" -import pytest -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from unittest.mock import Mock, patch -from langchain_core.language_models import BaseChatModel + import litellm +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT -from ra_aid.agent_utils import state_modifier, AgentState - -from ra_aid.agent_utils import create_agent, get_model_token_limit -from ra_aid.models_tokens import models_tokens +from ra_aid.agent_utils import ( + AgentState, + create_agent, + get_model_token_limit, + state_modifier, +) +from ra_aid.models_tokens import DEFAULT_TOKEN_LIMIT, models_tokens @pytest.fixture @@ -60,7 +63,6 @@ def test_get_model_token_limit_missing_config(mock_memory): assert token_limit is None - def test_get_model_token_limit_litellm_success(): """Test get_model_token_limit successfully getting limit from litellm.""" config = {"provider": "anthropic", "model": "claude-2"} diff --git a/tests/ra_aid/test_default_provider.py b/tests/ra_aid/test_default_provider.py index d58280a..a4ddfa9 100644 --- a/tests/ra_aid/test_default_provider.py +++ b/tests/ra_aid/test_default_provider.py @@ -1,17 +1,19 @@ """Tests for default provider and model configuration.""" -import os -import pytest from dataclasses import dataclass from typing import Optional -from ra_aid.env import validate_environment +import pytest + from ra_aid.__main__ import parse_arguments +from ra_aid.env import validate_environment + @dataclass class MockArgs: """Mock arguments for testing.""" - provider: str + + provider: Optional[str] = None expert_provider: Optional[str] = None model: Optional[str] = None expert_model: Optional[str] = None @@ -19,6 +21,7 @@ class MockArgs: research_only: bool = False chat: bool = False + @pytest.fixture def clean_env(monkeypatch): """Remove all provider-related environment variables.""" @@ -37,6 +40,7 @@ def clean_env(monkeypatch): monkeypatch.delenv(var, raising=False) yield + def test_default_anthropic_provider(clean_env, monkeypatch): """Test that Anthropic is the default provider when no environment variables are set.""" args = parse_arguments(["-m", "test message"]) @@ -44,62 +48,41 @@ def test_default_anthropic_provider(clean_env, monkeypatch): assert args.model == "claude-3-5-sonnet-20241022" -"""Unit tests for provider and model validation in research-only mode.""" - -import pytest -from dataclasses import dataclass -from argparse import Namespace -from ra_aid.env import validate_environment - - -@dataclass -class MockArgs: - """Mock command line arguments.""" - research_only: bool = False - provider: str = None - model: str = None - expert_provider: str = None - - TEST_CASES = [ pytest.param( "research_only_no_provider", MockArgs(research_only=True), {}, "No provider specified", - id="research_only_no_provider" + id="research_only_no_provider", ), pytest.param( "research_only_anthropic", MockArgs(research_only=True, provider="anthropic"), {}, None, - id="research_only_anthropic" + id="research_only_anthropic", ), pytest.param( "research_only_non_anthropic_no_model", MockArgs(research_only=True, provider="openai"), {}, "Model is required for non-Anthropic providers", - id="research_only_non_anthropic_no_model" + id="research_only_non_anthropic_no_model", ), pytest.param( "research_only_non_anthropic_with_model", MockArgs(research_only=True, provider="openai", model="gpt-4"), {}, None, - id="research_only_non_anthropic_with_model" - ) + id="research_only_non_anthropic_with_model", + ), ] @pytest.mark.parametrize("test_name,args,env_vars,expected_error", TEST_CASES) def test_research_only_provider_validation( - test_name: str, - args: MockArgs, - env_vars: dict, - expected_error: str, - monkeypatch + test_name: str, args: MockArgs, env_vars: dict, expected_error: str, monkeypatch ): """Test provider and model validation in research-only mode.""" # Set test environment variables diff --git a/tests/ra_aid/test_env.py b/tests/ra_aid/test_env.py index 4642134..4d00163 100644 --- a/tests/ra_aid/test_env.py +++ b/tests/ra_aid/test_env.py @@ -1,10 +1,12 @@ import os -import pytest from dataclasses import dataclass from typing import Optional +import pytest + from ra_aid.env import validate_environment + @dataclass class MockArgs: provider: str @@ -16,30 +18,44 @@ class MockArgs: planner_provider: Optional[str] = None planner_model: Optional[str] = None + @pytest.fixture def clean_env(monkeypatch): """Remove relevant environment variables before each test""" env_vars = [ - 'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY', - 'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY', - 'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'TAVILY_API_KEY', 'ANTHROPIC_MODEL' + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENAI_API_BASE", + "EXPERT_ANTHROPIC_API_KEY", + "EXPERT_OPENAI_API_KEY", + "EXPERT_OPENROUTER_API_KEY", + "EXPERT_OPENAI_API_BASE", + "TAVILY_API_KEY", + "ANTHROPIC_MODEL", ] for var in env_vars: monkeypatch.delenv(var, raising=False) + def test_anthropic_validation(clean_env, monkeypatch): - args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307") + args = MockArgs( + provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307" + ) # Should fail without API key with pytest.raises(SystemExit): validate_environment(args) # Should pass with API key and model - monkeypatch.setenv('ANTHROPIC_API_KEY', 'test-key') - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert not expert_enabled assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + def test_openai_validation(clean_env, monkeypatch): args = MockArgs(provider="openai", expert_provider="openai") @@ -49,13 +65,16 @@ def test_openai_validation(clean_env, monkeypatch): validate_environment(args) # Should pass with API key and enable expert mode with fallback - monkeypatch.setenv('OPENAI_API_KEY', 'test-key') - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + monkeypatch.setenv("OPENAI_API_KEY", "test-key") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "test-key" + def test_openai_compatible_validation(clean_env, monkeypatch): args = MockArgs(provider="openai-compatible", expert_provider="openai-compatible") @@ -65,126 +84,158 @@ def test_openai_compatible_validation(clean_env, monkeypatch): validate_environment(args) # Should fail with only API key - monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + monkeypatch.setenv("OPENAI_API_KEY", "test-key") with pytest.raises(SystemExit): validate_environment(args) # Should pass with both API key and base URL - monkeypatch.setenv('OPENAI_API_BASE', 'http://test') - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + monkeypatch.setenv("OPENAI_API_BASE", "http://test") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' - assert os.environ.get('EXPERT_OPENAI_API_BASE') == 'http://test' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "test-key" + assert os.environ.get("EXPERT_OPENAI_API_BASE") == "http://test" + def test_expert_fallback(clean_env, monkeypatch): args = MockArgs(provider="openai", expert_provider="openai") # Set only base API key - monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + monkeypatch.setenv("OPENAI_API_KEY", "test-key") # Should enable expert mode with fallback - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "test-key" # Should use explicit expert key if available - monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key') - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "expert-key" + def test_cross_provider_fallback(clean_env, monkeypatch): """Test that fallback works even when providers differ""" - args = MockArgs(provider="openai", expert_provider="anthropic", expert_model="claude-3-haiku-20240307") + args = MockArgs( + provider="openai", + expert_provider="anthropic", + expert_model="claude-3-haiku-20240307", + ) # Set base API key for main provider and expert provider - monkeypatch.setenv('OPENAI_API_KEY', 'openai-key') - monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') - monkeypatch.setenv('ANTHROPIC_MODEL', 'claude-3-haiku-20240307') + monkeypatch.setenv("OPENAI_API_KEY", "openai-key") + monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") + monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") # Should enable expert mode with fallback to ANTHROPIC base key - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing # Try with openai-compatible expert provider args = MockArgs(provider="anthropic", expert_provider="openai-compatible") - monkeypatch.setenv('OPENAI_API_KEY', 'openai-key') - monkeypatch.setenv('OPENAI_API_BASE', 'http://test') + monkeypatch.setenv("OPENAI_API_KEY", "openai-key") + monkeypatch.setenv("OPENAI_API_BASE", "http://test") - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key' - assert os.environ.get('EXPERT_OPENAI_API_BASE') == 'http://test' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "openai-key" + assert os.environ.get("EXPERT_OPENAI_API_BASE") == "http://test" + def test_no_warning_on_fallback(clean_env, monkeypatch): """Test that no warning is issued when fallback succeeds""" args = MockArgs(provider="openai", expert_provider="openai") # Set only base API key - monkeypatch.setenv('OPENAI_API_KEY', 'test-key') + monkeypatch.setenv("OPENAI_API_KEY", "test-key") # Should enable expert mode with fallback and no warnings - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'test-key' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "test-key" # Should use explicit expert key if available - monkeypatch.setenv('EXPERT_OPENAI_API_KEY', 'expert-key') - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key") + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'expert-key' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "expert-key" + def test_different_providers_no_expert_key(clean_env, monkeypatch): """Test behavior when providers differ and only base keys are available""" - args = MockArgs(provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307") + args = MockArgs( + provider="anthropic", expert_provider="openai", model="claude-3-haiku-20240307" + ) # Set only base keys - monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') - monkeypatch.setenv('OPENAI_API_KEY', 'openai-key') + monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") + monkeypatch.setenv("OPENAI_API_KEY", "openai-key") # Should enable expert mode and use base OPENAI key - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing def test_mixed_provider_openai_compatible(clean_env, monkeypatch): """Test behavior with openai-compatible expert and different main provider""" - args = MockArgs(provider="anthropic", expert_provider="openai-compatible", model="claude-3-haiku-20240307") + args = MockArgs( + provider="anthropic", + expert_provider="openai-compatible", + model="claude-3-haiku-20240307", + ) # Set all required keys and URLs - monkeypatch.setenv('ANTHROPIC_API_KEY', 'anthropic-key') - monkeypatch.setenv('OPENAI_API_KEY', 'openai-key') - monkeypatch.setenv('OPENAI_API_BASE', 'http://test') + monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") + monkeypatch.setenv("OPENAI_API_KEY", "openai-key") + monkeypatch.setenv("OPENAI_API_BASE", "http://test") # Should enable expert mode and use base openai key and URL - expert_enabled, expert_missing, web_research_enabled, web_research_missing = validate_environment(args) + expert_enabled, expert_missing, web_research_enabled, web_research_missing = ( + validate_environment(args) + ) assert expert_enabled assert not expert_missing assert not web_research_enabled - assert 'TAVILY_API_KEY environment variable is not set' in web_research_missing - assert os.environ.get('EXPERT_OPENAI_API_KEY') == 'openai-key' - assert os.environ.get('EXPERT_OPENAI_API_BASE') == 'http://test' + assert "TAVILY_API_KEY environment variable is not set" in web_research_missing + assert os.environ.get("EXPERT_OPENAI_API_KEY") == "openai-key" + assert os.environ.get("EXPERT_OPENAI_API_BASE") == "http://test" diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py index 403b2d1..831ed1c 100644 --- a/tests/ra_aid/test_llm.py +++ b/tests/ra_aid/test_llm.py @@ -1,179 +1,185 @@ import os -import pytest -from unittest.mock import patch, Mock -from langchain_openai.chat_models import ChatOpenAI -from langchain_anthropic.chat_models import ChatAnthropic -from langchain_google_genai.chat_models import ChatGoogleGenerativeAI -from langchain_core.messages import HumanMessage from dataclasses import dataclass -from ra_aid.agents.ciayn_agent import CiaynAgent +from unittest.mock import Mock, patch +import pytest +from langchain_anthropic.chat_models import ChatAnthropic +from langchain_core.messages import HumanMessage +from langchain_google_genai.chat_models import ChatGoogleGenerativeAI +from langchain_openai.chat_models import ChatOpenAI + +from ra_aid.agents.ciayn_agent import CiaynAgent from ra_aid.env import validate_environment from ra_aid.llm import ( - initialize_llm, - initialize_expert_llm, + create_llm_client, get_env_var, get_provider_config, - create_llm_client + initialize_expert_llm, + initialize_llm, ) + @pytest.fixture def clean_env(monkeypatch): """Remove relevant environment variables before each test""" env_vars = [ - 'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'OPENROUTER_API_KEY', - 'OPENAI_API_BASE', 'EXPERT_ANTHROPIC_API_KEY', 'EXPERT_OPENAI_API_KEY', - 'EXPERT_OPENROUTER_API_KEY', 'EXPERT_OPENAI_API_BASE', 'GEMINI_API_KEY', 'EXPERT_GEMINI_API_KEY' + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "OPENROUTER_API_KEY", + "OPENAI_API_BASE", + "EXPERT_ANTHROPIC_API_KEY", + "EXPERT_OPENAI_API_KEY", + "EXPERT_OPENROUTER_API_KEY", + "EXPERT_OPENAI_API_BASE", + "GEMINI_API_KEY", + "EXPERT_GEMINI_API_KEY", ] for var in env_vars: monkeypatch.delenv(var, raising=False) + @pytest.fixture def mock_openai(): """ Mock ChatOpenAI class for testing OpenAI provider initialization. Prevents actual API calls during testing. """ - with patch('ra_aid.llm.ChatOpenAI') as mock: + with patch("ra_aid.llm.ChatOpenAI") as mock: mock.return_value = Mock(spec=ChatOpenAI) yield mock + def test_initialize_expert_defaults(clean_env, mock_openai, monkeypatch): """Test expert LLM initialization with default parameters.""" monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key") - llm = initialize_expert_llm() - - mock_openai.assert_called_once_with( - api_key="test-key", - model="o1", - temperature=0 - ) + _llm = initialize_expert_llm() + + mock_openai.assert_called_once_with(api_key="test-key", model="o1", temperature=0) + def test_initialize_expert_openai_custom(clean_env, mock_openai, monkeypatch): """Test expert OpenAI initialization with custom parameters.""" monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key") - llm = initialize_expert_llm("openai", "gpt-4-preview") - + _llm = initialize_expert_llm("openai", "gpt-4-preview") + mock_openai.assert_called_once_with( - api_key="test-key", - model="gpt-4-preview", - temperature=0 + api_key="test-key", model="gpt-4-preview", temperature=0 ) + def test_initialize_expert_gemini(clean_env, mock_gemini, monkeypatch): """Test expert Gemini initialization.""" monkeypatch.setenv("EXPERT_GEMINI_API_KEY", "test-key") - llm = initialize_expert_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") - + _llm = initialize_expert_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") + mock_gemini.assert_called_once_with( - api_key="test-key", - model="gemini-2.0-flash-thinking-exp-1219", - temperature=0 + api_key="test-key", model="gemini-2.0-flash-thinking-exp-1219", temperature=0 ) + def test_initialize_expert_anthropic(clean_env, mock_anthropic, monkeypatch): """Test expert Anthropic initialization.""" monkeypatch.setenv("EXPERT_ANTHROPIC_API_KEY", "test-key") - llm = initialize_expert_llm("anthropic", "claude-3") - + _llm = initialize_expert_llm("anthropic", "claude-3") + mock_anthropic.assert_called_once_with( - api_key="test-key", - model_name="claude-3", - temperature=0 + api_key="test-key", model_name="claude-3", temperature=0 ) + def test_initialize_expert_openrouter(clean_env, mock_openai, monkeypatch): """Test expert OpenRouter initialization.""" monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key") - llm = initialize_expert_llm("openrouter", "models/mistral-large") - + _llm = initialize_expert_llm("openrouter", "models/mistral-large") + mock_openai.assert_called_once_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", model="models/mistral-large", - temperature=0 + temperature=0, ) + def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch): """Test expert OpenAI-compatible initialization.""" monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "test-key") monkeypatch.setenv("EXPERT_OPENAI_API_BASE", "http://test-url") - llm = initialize_expert_llm("openai-compatible", "local-model") - + _llm = initialize_expert_llm("openai-compatible", "local-model") + mock_openai.assert_called_once_with( api_key="test-key", base_url="http://test-url", model="local-model", - temperature=0 + temperature=0, ) + def test_initialize_expert_unsupported_provider(clean_env): """Test error handling for unsupported provider in expert mode.""" with pytest.raises(ValueError, match=r"Unsupported provider: unknown"): initialize_expert_llm("unknown", "model") + def test_estimate_tokens(): """Test token estimation functionality.""" # Test empty/None cases assert CiaynAgent._estimate_tokens(None) == 0 - assert CiaynAgent._estimate_tokens('') == 0 - + assert CiaynAgent._estimate_tokens("") == 0 + # Test string content - assert CiaynAgent._estimate_tokens('test') == 1 # 4 bytes - assert CiaynAgent._estimate_tokens('hello world') == 2 # 11 bytes - assert CiaynAgent._estimate_tokens('πŸš€') == 1 # 4 bytes - + assert CiaynAgent._estimate_tokens("test") == 1 # 4 bytes + assert CiaynAgent._estimate_tokens("hello world") == 2 # 11 bytes + assert CiaynAgent._estimate_tokens("πŸš€") == 1 # 4 bytes + # Test message content - msg = HumanMessage(content='test message') + msg = HumanMessage(content="test message") assert CiaynAgent._estimate_tokens(msg) == 3 # 11 bytes + def test_initialize_openai(clean_env, mock_openai): """Test OpenAI provider initialization""" os.environ["OPENAI_API_KEY"] = "test-key" - model = initialize_llm("openai", "gpt-4") - - mock_openai.assert_called_once_with( - api_key="test-key", - model="gpt-4" - ) + _model = initialize_llm("openai", "gpt-4") + + mock_openai.assert_called_once_with(api_key="test-key", model="gpt-4") + def test_initialize_gemini(clean_env, mock_gemini): """Test Gemini provider initialization""" os.environ["GEMINI_API_KEY"] = "test-key" - model = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") - + _model = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") + mock_gemini.assert_called_once_with( - api_key="test-key", - model="gemini-2.0-flash-thinking-exp-1219" + api_key="test-key", model="gemini-2.0-flash-thinking-exp-1219" ) + def test_initialize_anthropic(clean_env, mock_anthropic): """Test Anthropic provider initialization""" os.environ["ANTHROPIC_API_KEY"] = "test-key" - model = initialize_llm("anthropic", "claude-3") - - mock_anthropic.assert_called_once_with( - api_key="test-key", - model_name="claude-3" - ) + _model = initialize_llm("anthropic", "claude-3") + + mock_anthropic.assert_called_once_with(api_key="test-key", model_name="claude-3") + def test_initialize_openrouter(clean_env, mock_openai): """Test OpenRouter provider initialization""" os.environ["OPENROUTER_API_KEY"] = "test-key" - model = initialize_llm("openrouter", "mistral-large") - + _model = initialize_llm("openrouter", "mistral-large") + mock_openai.assert_called_once_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", - model="mistral-large" + model="mistral-large", ) + def test_initialize_openai_compatible(clean_env, mock_openai): """Test OpenAI-compatible provider initialization""" os.environ["OPENAI_API_KEY"] = "test-key" os.environ["OPENAI_API_BASE"] = "https://custom-endpoint/v1" - model = initialize_llm("openai-compatible", "local-model") - + _model = initialize_llm("openai-compatible", "local-model") + mock_openai.assert_called_once_with( api_key="test-key", base_url="https://custom-endpoint/v1", @@ -181,12 +187,14 @@ def test_initialize_openai_compatible(clean_env, mock_openai): temperature=0.3, ) + def test_initialize_unsupported_provider(clean_env): """Test initialization with unsupported provider raises ValueError""" with pytest.raises(ValueError) as exc_info: initialize_llm("unsupported", "model") assert str(exc_info.value) == "Unsupported provider: unsupported" + def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemini): """Test default temperature behavior for different providers.""" os.environ["OPENAI_API_KEY"] = "test-key" @@ -199,27 +207,19 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin api_key="test-key", base_url="http://test-url", model="test-model", - temperature=0.3 + temperature=0.3, ) - + # Test other providers don't set temperature by default initialize_llm("openai", "test-model") - mock_openai.assert_called_with( - api_key="test-key", - model="test-model" - ) - + mock_openai.assert_called_with(api_key="test-key", model="test-model") + initialize_llm("anthropic", "test-model") - mock_anthropic.assert_called_with( - api_key="test-key", - model_name="test-model" - ) - + mock_anthropic.assert_called_with(api_key="test-key", model_name="test-model") + initialize_llm("gemini", "test-model") - mock_gemini.assert_called_with( - api_key="test-key", - model="test-model" - ) + mock_gemini.assert_called_with(api_key="test-key", model="test-model") + def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemini): """Test explicit temperature setting for each provider.""" @@ -227,136 +227,129 @@ def test_explicit_temperature(clean_env, mock_openai, mock_anthropic, mock_gemin os.environ["ANTHROPIC_API_KEY"] = "test-key" os.environ["OPENROUTER_API_KEY"] = "test-key" os.environ["GEMINI_API_KEY"] = "test-key" - + test_temp = 0.7 - + # Test OpenAI initialize_llm("openai", "test-model", temperature=test_temp) mock_openai.assert_called_with( - api_key="test-key", - model="test-model", - temperature=test_temp + api_key="test-key", model="test-model", temperature=test_temp ) - + # Test Gemini initialize_llm("gemini", "test-model", temperature=test_temp) mock_gemini.assert_called_with( - api_key="test-key", - model="test-model", - temperature=test_temp + api_key="test-key", model="test-model", temperature=test_temp ) - + # Test Anthropic initialize_llm("anthropic", "test-model", temperature=test_temp) mock_anthropic.assert_called_with( - api_key="test-key", - model_name="test-model", - temperature=test_temp + api_key="test-key", model_name="test-model", temperature=test_temp ) - + # Test OpenRouter initialize_llm("openrouter", "test-model", temperature=test_temp) mock_openai.assert_called_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", model="test-model", - temperature=test_temp + temperature=test_temp, ) + def test_temperature_validation(clean_env, mock_openai): """Test temperature validation in command line arguments.""" from ra_aid.__main__ import parse_arguments - + # Test temperature below minimum with pytest.raises(SystemExit): - parse_arguments(['--message', 'test', '--temperature', '-0.1']) - + parse_arguments(["--message", "test", "--temperature", "-0.1"]) + # Test temperature above maximum with pytest.raises(SystemExit): - parse_arguments(['--message', 'test', '--temperature', '2.1']) - + parse_arguments(["--message", "test", "--temperature", "2.1"]) + # Test valid temperature - args = parse_arguments(['--message', 'test', '--temperature', '0.7']) + args = parse_arguments(["--message", "test", "--temperature", "0.7"]) assert args.temperature == 0.7 + def test_provider_name_validation(): """Test provider name validation and normalization.""" # Test all supported providers providers = ["openai", "anthropic", "openrouter", "openai-compatible", "gemini"] for provider in providers: try: - with patch(f'ra_aid.llm.ChatOpenAI'), patch('ra_aid.llm.ChatAnthropic'): + with patch("ra_aid.llm.ChatOpenAI"), patch("ra_aid.llm.ChatAnthropic"): initialize_llm(provider, "test-model") except ValueError: pytest.fail(f"Valid provider {provider} raised ValueError") - + # Test case sensitivity - with patch('ra_aid.llm.ChatOpenAI'): + with patch("ra_aid.llm.ChatOpenAI"): with pytest.raises(ValueError): initialize_llm("OpenAI", "test-model") -def test_initialize_llm_cross_provider(clean_env, mock_openai, mock_anthropic, mock_gemini, monkeypatch): + +def test_initialize_llm_cross_provider( + clean_env, mock_openai, mock_anthropic, mock_gemini, monkeypatch +): """Test initializing different providers in sequence.""" # Initialize OpenAI monkeypatch.setenv("OPENAI_API_KEY", "openai-key") - llm1 = initialize_llm("openai", "gpt-4") - + _llm1 = initialize_llm("openai", "gpt-4") + # Initialize Anthropic - monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") - llm2 = initialize_llm("anthropic", "claude-3") - + monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") + _llm2 = initialize_llm("anthropic", "claude-3") + # Initialize Gemini monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") - llm3 = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") - + _llm3 = initialize_llm("gemini", "gemini-2.0-flash-thinking-exp-1219") + # Verify both were initialized correctly - mock_openai.assert_called_once_with( - api_key="openai-key", - model="gpt-4" - ) + mock_openai.assert_called_once_with(api_key="openai-key", model="gpt-4") mock_anthropic.assert_called_once_with( - api_key="anthropic-key", - model_name="claude-3" + api_key="anthropic-key", model_name="claude-3" ) mock_gemini.assert_called_once_with( - api_key="gemini-key", - model="gemini-2.0-flash-thinking-exp-1219" + api_key="gemini-key", model="gemini-2.0-flash-thinking-exp-1219" ) + @dataclass class Args: """Test arguments class.""" + provider: str expert_provider: str model: str = None expert_model: str = None + def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch): """Test environment variable precedence and fallback.""" # Test get_env_var helper with fallback monkeypatch.setenv("TEST_KEY", "base-value") monkeypatch.setenv("EXPERT_TEST_KEY", "expert-value") - + assert get_env_var("TEST_KEY") == "base-value" assert get_env_var("TEST_KEY", expert=True) == "expert-value" - + # Test fallback when expert value not set monkeypatch.delenv("EXPERT_TEST_KEY", raising=False) assert get_env_var("TEST_KEY", expert=True) == "base-value" - + # Test provider config monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "expert-key") config = get_provider_config("openai", is_expert=True) assert config["api_key"] == "expert-key" - + # Test LLM client creation with expert mode - llm = create_llm_client("openai", "o1", is_expert=True) - mock_openai.assert_called_with( - api_key="expert-key", - model="o1", - temperature=0 - ) - + _llm = create_llm_client("openai", "o1", is_expert=True) + mock_openai.assert_called_with(api_key="expert-key", model="o1", temperature=0) + # Test environment validation monkeypatch.setenv("EXPERT_OPENAI_API_KEY", "") monkeypatch.delenv("OPENAI_API_KEY", raising=False) @@ -364,146 +357,164 @@ def test_environment_variable_precedence(clean_env, mock_openai, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "anthropic-key") monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") monkeypatch.setenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307") - + args = Args(provider="anthropic", expert_provider="openai") - expert_enabled, expert_missing, web_enabled, web_missing = validate_environment(args) + expert_enabled, expert_missing, web_enabled, web_missing = validate_environment( + args + ) assert not expert_enabled assert expert_missing assert not web_enabled assert web_missing + @pytest.fixture def mock_anthropic(): """ Mock ChatAnthropic class for testing Anthropic provider initialization. Prevents actual API calls during testing. """ - with patch('ra_aid.llm.ChatAnthropic') as mock: + with patch("ra_aid.llm.ChatAnthropic") as mock: mock.return_value = Mock(spec=ChatAnthropic) yield mock + @pytest.fixture def mock_gemini(): """Mock ChatGoogleGenerativeAI class for testing Gemini provider initialization.""" - with patch('ra_aid.llm.ChatGoogleGenerativeAI') as mock: + with patch("ra_aid.llm.ChatGoogleGenerativeAI") as mock: mock.return_value = Mock(spec=ChatGoogleGenerativeAI) yield mock + @pytest.fixture def mock_deepseek_reasoner(): """Mock ChatDeepseekReasoner for testing DeepSeek provider initialization.""" - with patch('ra_aid.llm.ChatDeepseekReasoner') as mock: + with patch("ra_aid.llm.ChatDeepseekReasoner") as mock: mock.return_value = Mock() yield mock -def test_initialize_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + +def test_initialize_deepseek( + clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch +): """Test DeepSeek provider initialization with different models.""" monkeypatch.setenv("DEEPSEEK_API_KEY", "test-key") - + # Test with reasoner model - model = initialize_llm("deepseek", "deepseek-reasoner") + _model = initialize_llm("deepseek", "deepseek-reasoner") mock_deepseek_reasoner.assert_called_with( api_key="test-key", base_url="https://api.deepseek.com", temperature=1, - model="deepseek-reasoner" + model="deepseek-reasoner", ) - + # Test with non-reasoner model - model = initialize_llm("deepseek", "deepseek-chat") + _model = initialize_llm("deepseek", "deepseek-chat") mock_openai.assert_called_with( api_key="test-key", base_url="https://api.deepseek.com", temperature=1, - model="deepseek-chat" + model="deepseek-chat", ) -def test_initialize_expert_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + +def test_initialize_expert_deepseek( + clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch +): """Test expert DeepSeek provider initialization.""" monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "test-key") - + # Test with reasoner model - model = initialize_expert_llm("deepseek", "deepseek-reasoner") + _model = initialize_expert_llm("deepseek", "deepseek-reasoner") mock_deepseek_reasoner.assert_called_with( api_key="test-key", base_url="https://api.deepseek.com", temperature=0, - model="deepseek-reasoner" + model="deepseek-reasoner", ) - + # Test with non-reasoner model - model = initialize_expert_llm("deepseek", "deepseek-chat") + _model = initialize_expert_llm("deepseek", "deepseek-chat") mock_openai.assert_called_with( api_key="test-key", base_url="https://api.deepseek.com", temperature=0, - model="deepseek-chat" + model="deepseek-chat", ) -def test_initialize_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): + +def test_initialize_openrouter_deepseek( + clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch +): """Test OpenRouter DeepSeek model initialization.""" monkeypatch.setenv("OPENROUTER_API_KEY", "test-key") - + # Test with DeepSeek R1 model - model = initialize_llm("openrouter", "deepseek/deepseek-r1") + _model = initialize_llm("openrouter", "deepseek/deepseek-r1") mock_deepseek_reasoner.assert_called_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", temperature=1, - model="deepseek/deepseek-r1" - ) - - # Test with non-DeepSeek model - model = initialize_llm("openrouter", "mistral/mistral-large") - mock_openai.assert_called_with( - api_key="test-key", - base_url="https://openrouter.ai/api/v1", - model="mistral/mistral-large" + model="deepseek/deepseek-r1", ) -def test_initialize_expert_openrouter_deepseek(clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch): - """Test expert OpenRouter DeepSeek model initialization.""" - monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key") - - # Test with DeepSeek R1 model via create_llm_client - model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True) - mock_deepseek_reasoner.assert_called_with( - api_key="test-key", - base_url="https://openrouter.ai/api/v1", - temperature=0, - model="deepseek/deepseek-r1" - ) - # Test with non-DeepSeek model - model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True) + _model = initialize_llm("openrouter", "mistral/mistral-large") mock_openai.assert_called_with( api_key="test-key", base_url="https://openrouter.ai/api/v1", model="mistral/mistral-large", - temperature=0 ) + +def test_initialize_expert_openrouter_deepseek( + clean_env, mock_openai, mock_deepseek_reasoner, monkeypatch +): + """Test expert OpenRouter DeepSeek model initialization.""" + monkeypatch.setenv("EXPERT_OPENROUTER_API_KEY", "test-key") + + # Test with DeepSeek R1 model via create_llm_client + _model = create_llm_client("openrouter", "deepseek/deepseek-r1", is_expert=True) + mock_deepseek_reasoner.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + temperature=0, + model="deepseek/deepseek-r1", + ) + + # Test with non-DeepSeek model + _model = create_llm_client("openrouter", "mistral/mistral-large", is_expert=True) + mock_openai.assert_called_with( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + model="mistral/mistral-large", + temperature=0, + ) + + def test_deepseek_environment_fallback(clean_env, mock_deepseek_reasoner, monkeypatch): """Test DeepSeek environment variable fallback behavior.""" # Test environment variable helper with fallback monkeypatch.setenv("DEEPSEEK_API_KEY", "base-key") assert get_env_var("DEEPSEEK_API_KEY", expert=True) == "base-key" - + # Test provider config with fallback config = get_provider_config("deepseek", is_expert=True) assert config["api_key"] == "base-key" assert config["base_url"] == "https://api.deepseek.com" - + # Test with expert key monkeypatch.setenv("EXPERT_DEEPSEEK_API_KEY", "expert-key") config = get_provider_config("deepseek", is_expert=True) assert config["api_key"] == "expert-key" - + # Test client creation with expert key - model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True) + _model = create_llm_client("deepseek", "deepseek-reasoner", is_expert=True) mock_deepseek_reasoner.assert_called_with( api_key="expert-key", base_url="https://api.deepseek.com", temperature=0, - model="deepseek-reasoner" + model="deepseek-reasoner", ) diff --git a/tests/ra_aid/test_main.py b/tests/ra_aid/test_main.py index 5895716..8e769a1 100644 --- a/tests/ra_aid/test_main.py +++ b/tests/ra_aid/test_main.py @@ -1,19 +1,21 @@ """Unit tests for __main__.py argument parsing.""" import pytest + from ra_aid.__main__ import parse_arguments -from ra_aid.tools.memory import _global_memory from ra_aid.config import DEFAULT_RECURSION_LIMIT +from ra_aid.tools.memory import _global_memory @pytest.fixture def mock_dependencies(monkeypatch): """Mock all dependencies needed for main().""" - monkeypatch.setattr('ra_aid.__main__.check_dependencies', lambda: None) - - monkeypatch.setattr('ra_aid.__main__.validate_environment', - lambda args: (True, [], True, [])) - + monkeypatch.setattr("ra_aid.__main__.check_dependencies", lambda: None) + + monkeypatch.setattr( + "ra_aid.__main__.validate_environment", lambda args: (True, [], True, []) + ) + def mock_config_update(*args, **kwargs): config = _global_memory.get("config", {}) if kwargs.get("temperature"): @@ -21,27 +23,31 @@ def mock_dependencies(monkeypatch): _global_memory["config"] = config return None - monkeypatch.setattr('ra_aid.__main__.initialize_llm', - mock_config_update) - - monkeypatch.setattr('ra_aid.__main__.run_research_agent', - lambda *args, **kwargs: None) + monkeypatch.setattr("ra_aid.__main__.initialize_llm", mock_config_update) + + monkeypatch.setattr( + "ra_aid.__main__.run_research_agent", lambda *args, **kwargs: None + ) + def test_recursion_limit_in_global_config(mock_dependencies): """Test that recursion limit is correctly set in global config.""" - from ra_aid.__main__ import main import sys from unittest.mock import patch - + + from ra_aid.__main__ import main + _global_memory.clear() - - with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message']): + + with patch.object(sys, "argv", ["ra-aid", "-m", "test message"]): main() assert _global_memory["config"]["recursion_limit"] == DEFAULT_RECURSION_LIMIT - + _global_memory.clear() - - with patch.object(sys, 'argv', ['ra-aid', '-m', 'test message', '--recursion-limit', '50']): + + with patch.object( + sys, "argv", ["ra-aid", "-m", "test message", "--recursion-limit", "50"] + ): main() assert _global_memory["config"]["recursion_limit"] == 50 @@ -60,23 +66,35 @@ def test_zero_recursion_limit(): def test_config_settings(mock_dependencies): """Test that various settings are correctly applied in global config.""" - from ra_aid.__main__ import main import sys from unittest.mock import patch + from ra_aid.__main__ import main + _global_memory.clear() - - with patch.object(sys, 'argv', [ - 'ra-aid', '-m', 'test message', - '--cowboy-mode', - '--research-only', - '--provider', 'anthropic', - '--model', 'claude-3-5-sonnet-20241022', - '--expert-provider', 'openai', - '--expert-model', 'gpt-4', - '--temperature', '0.7', - '--disable-limit-tokens' - ]): + + with patch.object( + sys, + "argv", + [ + "ra-aid", + "-m", + "test message", + "--cowboy-mode", + "--research-only", + "--provider", + "anthropic", + "--model", + "claude-3-5-sonnet-20241022", + "--expert-provider", + "openai", + "--expert-model", + "gpt-4", + "--temperature", + "0.7", + "--disable-limit-tokens", + ], + ): main() config = _global_memory["config"] assert config["cowboy_mode"] is True @@ -90,20 +108,25 @@ def test_config_settings(mock_dependencies): def test_temperature_validation(mock_dependencies): """Test that temperature argument is correctly passed to initialize_llm.""" - from ra_aid.__main__ import main, initialize_llm import sys from unittest.mock import patch + from ra_aid.__main__ import main + _global_memory.clear() - - with patch('ra_aid.__main__.initialize_llm') as mock_init_llm: - with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '0.7']): + + with patch("ra_aid.__main__.initialize_llm") as mock_init_llm: + with patch.object( + sys, "argv", ["ra-aid", "-m", "test", "--temperature", "0.7"] + ): main() mock_init_llm.assert_called_once() - assert mock_init_llm.call_args.kwargs['temperature'] == 0.7 + assert mock_init_llm.call_args.kwargs["temperature"] == 0.7 with pytest.raises(SystemExit): - with patch.object(sys, 'argv', ['ra-aid', '-m', 'test', '--temperature', '2.1']): + with patch.object( + sys, "argv", ["ra-aid", "-m", "test", "--temperature", "2.1"] + ): main() @@ -125,19 +148,30 @@ def test_missing_message(): def test_research_model_provider_args(mock_dependencies): """Test that research-specific model/provider args are correctly stored in config.""" - from ra_aid.__main__ import main import sys from unittest.mock import patch + from ra_aid.__main__ import main + _global_memory.clear() - - with patch.object(sys, 'argv', [ - 'ra-aid', '-m', 'test message', - '--research-provider', 'anthropic', - '--research-model', 'claude-3-haiku-20240307', - '--planner-provider', 'openai', - '--planner-model', 'gpt-4' - ]): + + with patch.object( + sys, + "argv", + [ + "ra-aid", + "-m", + "test message", + "--research-provider", + "anthropic", + "--research-model", + "claude-3-haiku-20240307", + "--planner-provider", + "openai", + "--planner-model", + "gpt-4", + ], + ): main() config = _global_memory["config"] assert config["research_provider"] == "anthropic" @@ -148,17 +182,18 @@ def test_research_model_provider_args(mock_dependencies): def test_planner_model_provider_args(mock_dependencies): """Test that planner provider/model args fall back to main config when not specified.""" - from ra_aid.__main__ import main import sys from unittest.mock import patch + from ra_aid.__main__ import main + _global_memory.clear() - - with patch.object(sys, 'argv', [ - 'ra-aid', '-m', 'test message', - '--provider', 'openai', - '--model', 'gpt-4' - ]): + + with patch.object( + sys, + "argv", + ["ra-aid", "-m", "test message", "--provider", "openai", "--model", "gpt-4"], + ): main() config = _global_memory["config"] assert config["planner_provider"] == "openai" diff --git a/tests/ra_aid/test_programmer.py b/tests/ra_aid/test_programmer.py index 3fdf0f3..583951d 100644 --- a/tests/ra_aid/test_programmer.py +++ b/tests/ra_aid/test_programmer.py @@ -1,4 +1,5 @@ import pytest + from ra_aid.tools.programmer import parse_aider_flags, run_programming_task # Test cases for parse_aider_flags function @@ -7,65 +8,57 @@ test_cases = [ ( "yes-always,dark-mode", ["--yes-always", "--dark-mode"], - "basic comma separated flags without dashes" + "basic comma separated flags without dashes", ), ( "--yes-always,--dark-mode", ["--yes-always", "--dark-mode"], - "comma separated flags with dashes" + "comma separated flags with dashes", ), ( "yes-always, dark-mode", ["--yes-always", "--dark-mode"], - "comma separated flags with space" + "comma separated flags with space", ), ( "--yes-always, --dark-mode", ["--yes-always", "--dark-mode"], - "comma separated flags with dashes and space" - ), - ( - "", - [], - "empty string" + "comma separated flags with dashes and space", ), + ("", [], "empty string"), ( " yes-always , dark-mode ", ["--yes-always", "--dark-mode"], - "flags with extra whitespace" + "flags with extra whitespace", ), - ( - "--yes-always", - ["--yes-always"], - "single flag with dashes" - ), - ( - "yes-always", - ["--yes-always"], - "single flag without dashes" - ) + ("--yes-always", ["--yes-always"], "single flag with dashes"), + ("yes-always", ["--yes-always"], "single flag without dashes"), ] + @pytest.mark.parametrize("input_flags,expected,description", test_cases) def test_parse_aider_flags(input_flags, expected, description): """Table-driven test for parse_aider_flags function.""" result = parse_aider_flags(input_flags) assert result == expected, f"Failed test case: {description}" + def test_aider_config_flag(mocker): """Test that aider config flag is properly included in the command when specified.""" mock_memory = { - 'config': {'aider_config': '/path/to/config.yml'}, - 'related_files': {} + "config": {"aider_config": "/path/to/config.yml"}, + "related_files": {}, } - mocker.patch('ra_aid.tools.programmer._global_memory', mock_memory) - + mocker.patch("ra_aid.tools.programmer._global_memory", mock_memory) + # Mock the run_interactive_command to capture the command that would be run - mock_run = mocker.patch('ra_aid.tools.programmer.run_interactive_command', return_value=(b'', 0)) - + mock_run = mocker.patch( + "ra_aid.tools.programmer.run_interactive_command", return_value=(b"", 0) + ) + run_programming_task("test instruction") - + args = mock_run.call_args[0][0] # Get the first positional arg (command list) - assert '--config' in args - config_index = args.index('--config') - assert args[config_index + 1] == '/path/to/config.yml' + assert "--config" in args + config_index = args.index("--config") + assert args[config_index + 1] == "/path/to/config.yml" diff --git a/tests/ra_aid/test_provider_integration.py b/tests/ra_aid/test_provider_integration.py index 2c2e88d..120caf6 100644 --- a/tests/ra_aid/test_provider_integration.py +++ b/tests/ra_aid/test_provider_integration.py @@ -1,19 +1,19 @@ """Integration tests for provider validation and environment handling.""" import os -import pytest from dataclasses import dataclass from typing import Optional +import pytest + from ra_aid.env import validate_environment from ra_aid.provider_strategy import ( - ProviderFactory, - ValidationResult, AnthropicStrategy, - OpenAIStrategy, - OpenAICompatibleStrategy, - OpenRouterStrategy, GeminiStrategy, + OpenAICompatibleStrategy, + OpenAIStrategy, + OpenRouterStrategy, + ProviderFactory, ) diff --git a/tests/ra_aid/test_tool_configs.py b/tests/ra_aid/test_tool_configs.py index f935811..7231d61 100644 --- a/tests/ra_aid/test_tool_configs.py +++ b/tests/ra_aid/test_tool_configs.py @@ -1,72 +1,84 @@ -import pytest from ra_aid.tool_configs import ( + get_implementation_tools, + get_planning_tools, get_read_only_tools, get_research_tools, - get_planning_tools, - get_implementation_tools, - get_web_research_tools + get_web_research_tools, ) + def test_get_read_only_tools(): # Test without human interaction tools = get_read_only_tools(human_interaction=False) assert len(tools) > 0 assert all(callable(tool) for tool in tools) - + # Test with human interaction tools_with_human = get_read_only_tools(human_interaction=True) assert len(tools_with_human) == len(tools) + 1 + def test_get_research_tools(): # Test basic research tools tools = get_research_tools() assert len(tools) > 0 assert all(callable(tool) for tool in tools) - + # Test without expert tools_no_expert = get_research_tools(expert_enabled=False) assert len(tools_no_expert) < len(tools) - + # Test research-only mode tools_research_only = get_research_tools(research_only=True) assert len(tools_research_only) < len(tools) + def test_get_planning_tools(): # Test with expert enabled tools = get_planning_tools(expert_enabled=True) assert len(tools) > 0 assert all(callable(tool) for tool in tools) - + # Test without expert tools_no_expert = get_planning_tools(expert_enabled=False) assert len(tools_no_expert) < len(tools) + def test_get_implementation_tools(): # Test with expert enabled tools = get_implementation_tools(expert_enabled=True) assert len(tools) > 0 assert all(callable(tool) for tool in tools) - + # Test without expert tools_no_expert = get_implementation_tools(expert_enabled=False) assert len(tools_no_expert) < len(tools) + def test_get_web_research_tools(): # Test with expert enabled tools = get_web_research_tools(expert_enabled=True) assert len(tools) == 5 assert all(callable(tool) for tool in tools) - + # Get tool names and verify exact matches tool_names = [tool.name for tool in tools] - expected_names = ['emit_expert_context', 'ask_expert', 'web_search_tavily', 'emit_research_notes', 'task_completed'] + expected_names = [ + "emit_expert_context", + "ask_expert", + "web_search_tavily", + "emit_research_notes", + "task_completed", + ] assert sorted(tool_names) == sorted(expected_names) - + # Test without expert enabled tools_no_expert = get_web_research_tools(expert_enabled=False) assert len(tools_no_expert) == 3 assert all(callable(tool) for tool in tools_no_expert) - + # Verify exact tool names when expert is disabled tool_names_no_expert = [tool.name for tool in tools_no_expert] - assert sorted(tool_names_no_expert) == sorted(['web_search_tavily', 'emit_research_notes', 'task_completed']) + assert sorted(tool_names_no_expert) == sorted( + ["web_search_tavily", "emit_research_notes", "task_completed"] + ) diff --git a/tests/ra_aid/test_utils.py b/tests/ra_aid/test_utils.py index 0ce2ba1..293f78c 100644 --- a/tests/ra_aid/test_utils.py +++ b/tests/ra_aid/test_utils.py @@ -1,6 +1,5 @@ """Tests for utility functions.""" -import pytest from ra_aid.text.processing import truncate_output @@ -9,10 +8,10 @@ def test_normal_truncation(): # Create input with 10 lines input_lines = [f"Line {i}\n" for i in range(10)] input_text = "".join(input_lines) - + # Truncate to 5 lines result = truncate_output(input_text, max_lines=5) - + # Verify truncation message and content assert "[5 lines of output truncated]" in result assert "Line 5\n" in result @@ -25,7 +24,7 @@ def test_no_truncation_needed(): """Test when input is shorter than max_lines.""" input_text = "Line 1\nLine 2\nLine 3\n" result = truncate_output(input_text, max_lines=5) - + # Should return original text unchanged assert result == input_text assert "[lines of output truncated]" not in result @@ -42,9 +41,9 @@ def test_exact_max_lines(): # Create input with exactly 5 lines input_lines = [f"Line {i}\n" for i in range(5)] input_text = "".join(input_lines) - + result = truncate_output(input_text, max_lines=5) - + # Should return original text unchanged assert result == input_text assert "[lines of output truncated]" not in result @@ -54,9 +53,9 @@ def test_different_line_endings(): """Test with different line endings (\\n, \\r\\n, \\r).""" # Mix of different line endings input_text = "Line 1\nLine 2\r\nLine 3\rLine 4\nLine 5\r\nLine 6" - + result = truncate_output(input_text, max_lines=3) - + # Should preserve line endings in truncated output assert "[3 lines of output truncated]" in result assert "Line 4" in result @@ -71,12 +70,12 @@ def test_ansi_sequences(): "\033[31mRed Line 1\033[0m\n", "\033[32mGreen Line 2\033[0m\n", "\033[34mBlue Line 3\033[0m\n", - "\033[33mYellow Line 4\033[0m\n" + "\033[33mYellow Line 4\033[0m\n", ] input_text = "".join(input_lines) - + result = truncate_output(input_text, max_lines=2) - + # Should preserve ANSI sequences in truncated output assert "[2 lines of output truncated]" in result assert "\033[34mBlue Line 3\033[0m" in result @@ -89,10 +88,10 @@ def test_custom_max_lines(): # Create input with 100 lines input_lines = [f"Line {i}\n" for i in range(100)] input_text = "".join(input_lines) - + # Test with custom max_lines=10 result = truncate_output(input_text, max_lines=10) - + # Should have truncation message and last 10 lines assert "[90 lines of output truncated]" in result assert "Line 90\n" in result @@ -105,9 +104,9 @@ def test_no_trailing_newline(): """Test with input that doesn't end in newline.""" input_lines = [f"Line {i}" for i in range(10)] input_text = "\n".join(input_lines) # No trailing newline - + result = truncate_output(input_text, max_lines=5) - + # Should handle truncation correctly without trailing newline assert "[5 lines of output truncated]" in result assert "Line 5" in result diff --git a/tests/ra_aid/tools/test_execution.py b/tests/ra_aid/tools/test_execution.py index ee9951c..2d1961e 100644 --- a/tests/ra_aid/tools/test_execution.py +++ b/tests/ra_aid/tools/test_execution.py @@ -1,14 +1,15 @@ """Tests for test execution utilities.""" +from unittest.mock import patch + import pytest -from unittest.mock import Mock, patch + from ra_aid.tools.handle_user_defined_test_cmd_execution import execute_test_command # Test cases for execute_test_command test_cases = [ # Format: (name, config, original_prompt, test_attempts, auto_test, # mock_responses, expected_result) - # Case 1: No test command configured ( "no_test_command", @@ -17,9 +18,8 @@ test_cases = [ 0, False, {}, - (True, "original prompt", False, 0) + (True, "original prompt", False, 0), ), - # Case 2: User declines to run test ( "user_declines_test", @@ -28,9 +28,8 @@ test_cases = [ 0, False, {"ask_human_response": "n"}, - (True, "original prompt", False, 0) + (True, "original prompt", False, 0), ), - # Case 3: User enables auto-test ( "user_enables_auto_test", @@ -40,11 +39,10 @@ test_cases = [ False, { "ask_human_response": "a", - "shell_cmd_result": {"success": True, "output": "All tests passed"} + "shell_cmd_result": {"success": True, "output": "All tests passed"}, }, - (True, "original prompt", True, 1) + (True, "original prompt", True, 1), ), - # Case 4: Auto-test success ( "auto_test_success", @@ -53,9 +51,8 @@ test_cases = [ 0, True, {"shell_cmd_result": {"success": True, "output": "All tests passed"}}, - (True, "original prompt", True, 1) + (True, "original prompt", True, 1), ), - # Case 5: Auto-test failure with retry ( "auto_test_failure_retry", @@ -64,9 +61,13 @@ test_cases = [ 0, True, {"shell_cmd_result": {"success": False, "output": "Test failed"}}, - (False, "original prompt. Previous attempt failed with: Test failed", True, 1) + ( + False, + "original prompt. Previous attempt failed with: Test failed", + True, + 1, + ), ), - # Case 6: Max retries reached ( "max_retries_reached", @@ -75,9 +76,8 @@ test_cases = [ 3, True, {}, - (True, "original prompt", True, 3) + (True, "original prompt", True, 3), ), - # Case 7: User runs test manually ( "manual_test_success", @@ -87,11 +87,10 @@ test_cases = [ False, { "ask_human_response": "y", - "shell_cmd_result": {"success": True, "output": "All tests passed"} + "shell_cmd_result": {"success": True, "output": "All tests passed"}, }, - (True, "original prompt", False, 1) + (True, "original prompt", False, 1), ), - # Case 8: Manual test failure ( "manual_test_failure", @@ -101,11 +100,15 @@ test_cases = [ False, { "ask_human_response": "y", - "shell_cmd_result": {"success": False, "output": "Test failed"} + "shell_cmd_result": {"success": False, "output": "Test failed"}, }, - (False, "original prompt. Previous attempt failed with: Test failed", False, 1) + ( + False, + "original prompt. Previous attempt failed with: Test failed", + False, + 1, + ), ), - # Case 9: Manual test error ( "manual_test_error", @@ -115,11 +118,10 @@ test_cases = [ False, { "ask_human_response": "y", - "shell_cmd_result_error": Exception("Command failed") + "shell_cmd_result_error": Exception("Command failed"), }, - (True, "original prompt", False, 1) + (True, "original prompt", False, 1), ), - # Case 10: Auto-test error ( "auto_test_error", @@ -127,17 +129,16 @@ test_cases = [ "original prompt", 0, True, - { - "shell_cmd_result_error": Exception("Command failed") - }, - (True, "original prompt", True, 1) + {"shell_cmd_result_error": Exception("Command failed")}, + (True, "original prompt", True, 1), ), ] + @pytest.mark.parametrize( "name,config,original_prompt,test_attempts,auto_test,mock_responses,expected", test_cases, - ids=[case[0] for case in test_cases] + ids=[case[0] for case in test_cases], ) def test_execute_test_command( name: str, @@ -149,7 +150,7 @@ def test_execute_test_command( expected: tuple, ) -> None: """Test execute_test_command with different scenarios. - + Args: name: Test case name config: Test configuration @@ -159,61 +160,68 @@ def test_execute_test_command( mock_responses: Mock response data expected: Expected result tuple """ - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.ask_human") as mock_ask_human, \ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run_cmd, \ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.console") as mock_console, \ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.logger") as mock_logger: - + with ( + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.ask_human" + ) as mock_ask_human, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run_cmd, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.console" + ) as _mock_console, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.logger" + ) as mock_logger, + ): # Configure mocks based on mock_responses if "ask_human_response" in mock_responses: mock_ask_human.invoke.return_value = mock_responses["ask_human_response"] - + if "shell_cmd_result_error" in mock_responses: mock_run_cmd.side_effect = mock_responses["shell_cmd_result_error"] elif "shell_cmd_result" in mock_responses: mock_run_cmd.return_value = mock_responses["shell_cmd_result"] - + # Execute test command - result = execute_test_command( - config, - original_prompt, - test_attempts, - auto_test - ) - + result = execute_test_command(config, original_prompt, test_attempts, auto_test) + # Verify result matches expected assert result == expected, f"Test case '{name}' failed" - + # Verify mock interactions if config.get("test_cmd") and not auto_test: mock_ask_human.invoke.assert_called_once() - + if auto_test and test_attempts < config.get("max_test_cmd_retries", 5): if config.get("test_cmd"): # Verify run_shell_command called with command and default timeout - mock_run_cmd.assert_called_once_with(config["test_cmd"], timeout=config.get('timeout', 30)) - + mock_run_cmd.assert_called_once_with( + config["test_cmd"], timeout=config.get("timeout", 30) + ) + # Verify logging for max retries if test_attempts >= config.get("max_test_cmd_retries", 5): mock_logger.warning.assert_called_once_with("Max test retries reached") + def test_execute_test_command_error_handling() -> None: """Test error handling in execute_test_command.""" config = {"test_cmd": "pytest"} - - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run_cmd, \ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.logger") as mock_logger: - + + with ( + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run_cmd, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.logger" + ) as mock_logger, + ): # Simulate run_shell_command raising an exception mock_run_cmd.side_effect = Exception("Command failed") - - result = execute_test_command( - config, - "original prompt", - 0, - True - ) - + + result = execute_test_command(config, "original prompt", 0, True) + # Should handle error and continue assert result == (True, "original prompt", True, 1) - mock_logger.warning.assert_called_once() \ No newline at end of file + mock_logger.warning.assert_called_once() diff --git a/tests/ra_aid/tools/test_expert.py b/tests/ra_aid/tools/test_expert.py index 831e3a5..17c4acd 100644 --- a/tests/ra_aid/tools/test_expert.py +++ b/tests/ra_aid/tools/test_expert.py @@ -1,8 +1,13 @@ -import os +from unittest.mock import patch + import pytest -from pathlib import Path -from unittest.mock import patch, mock_open -from ra_aid.tools.expert import read_files_with_limit, emit_expert_context, expert_context + +from ra_aid.tools.expert import ( + emit_expert_context, + expert_context, + read_files_with_limit, +) + @pytest.fixture def temp_test_files(tmp_path): @@ -10,96 +15,106 @@ def temp_test_files(tmp_path): file1 = tmp_path / "test1.txt" file2 = tmp_path / "test2.txt" file3 = tmp_path / "test3.txt" - + file1.write_text("Line 1\nLine 2\nLine 3\n") file2.write_text("File 2 Line 1\nFile 2 Line 2\n") file3.write_text("") # Empty file - + return tmp_path, [file1, file2, file3] + def test_read_files_with_limit_basic(temp_test_files): """Test basic successful reading of multiple files.""" tmp_path, files = temp_test_files result = read_files_with_limit([str(f) for f in files]) - + assert "## File:" in result assert "Line 1" in result assert "File 2 Line 1" in result assert str(files[0]) in result assert str(files[1]) in result + def test_read_files_with_limit_empty_file(temp_test_files): """Test handling of empty files.""" tmp_path, files = temp_test_files result = read_files_with_limit([str(files[2])]) # Empty file assert result == "" # Empty files should be excluded from output + def test_read_files_with_limit_nonexistent_file(temp_test_files): """Test handling of nonexistent files.""" tmp_path, files = temp_test_files nonexistent = str(tmp_path / "nonexistent.txt") result = read_files_with_limit([str(files[0]), nonexistent]) - + assert "Line 1" in result # Should contain content from existing file assert "nonexistent.txt" not in result # Shouldn't include non-existent file + def test_read_files_with_limit_line_limit(temp_test_files): """Test enforcement of line limit.""" tmp_path, files = temp_test_files result = read_files_with_limit([str(files[0]), str(files[1])], max_lines=2) - + assert "truncated" in result assert "Line 1" in result assert "Line 2" in result assert "File 2 Line 1" not in result # Should be truncated before reaching file 2 -@patch('builtins.open') + +@patch("builtins.open") def test_read_files_with_limit_permission_error(mock_open_func, temp_test_files): """Test handling of permission errors.""" mock_open_func.side_effect = PermissionError("Permission denied") tmp_path, files = temp_test_files - + result = read_files_with_limit([str(files[0])]) assert result == "" # Should return empty string on permission error -@patch('builtins.open') + +@patch("builtins.open") def test_read_files_with_limit_io_error(mock_open_func, temp_test_files): """Test handling of IO errors.""" mock_open_func.side_effect = IOError("IO Error") tmp_path, files = temp_test_files - + result = read_files_with_limit([str(files[0])]) assert result == "" # Should return empty string on IO error + def test_read_files_with_limit_encoding_error(temp_test_files): """Test handling of encoding errors.""" tmp_path, files = temp_test_files - + # Create a file with invalid UTF-8 invalid_file = tmp_path / "invalid.txt" - with open(invalid_file, 'wb') as f: - f.write(b'\xFF\xFE\x00\x00') # Invalid UTF-8 - + with open(invalid_file, "wb") as f: + f.write(b"\xff\xfe\x00\x00") # Invalid UTF-8 + result = read_files_with_limit([str(invalid_file)]) assert result == "" # Should return empty string on encoding error + def test_expert_context_management(): """Test expert context global state management.""" # Clear any existing context - expert_context['text'].clear() - expert_context['files'].clear() - + expert_context["text"].clear() + expert_context["files"].clear() + # Test adding context result1 = emit_expert_context.invoke("Test context 1") assert "Context added" in result1 - assert len(expert_context['text']) == 1 - assert expert_context['text'][0] == "Test context 1" - + assert len(expert_context["text"]) == 1 + assert expert_context["text"][0] == "Test context 1" + # Test adding multiple contexts result2 = emit_expert_context.invoke("Test context 2") assert "Context added" in result2 - assert len(expert_context['text']) == 2 - assert expert_context['text'][1] == "Test context 2" - + assert len(expert_context["text"]) == 2 + assert expert_context["text"][1] == "Test context 2" + # Test context accumulation - assert all(ctx in expert_context['text'] for ctx in ["Test context 1", "Test context 2"]) + assert all( + ctx in expert_context["text"] for ctx in ["Test context 1", "Test context 2"] + ) diff --git a/tests/ra_aid/tools/test_file_str_replace.py b/tests/ra_aid/tools/test_file_str_replace.py index 181e076..265f5cb 100644 --- a/tests/ra_aid/tools/test_file_str_replace.py +++ b/tests/ra_aid/tools/test_file_str_replace.py @@ -1,9 +1,11 @@ import os +from unittest.mock import patch + import pytest -from pathlib import Path -from unittest.mock import patch, mock_open + from ra_aid.tools.file_str_replace import file_str_replace + @pytest.fixture def temp_test_dir(tmp_path): """Create a temporary test directory.""" @@ -11,164 +13,156 @@ def temp_test_dir(tmp_path): test_dir.mkdir(exist_ok=True) return test_dir + def test_basic_replacement(temp_test_dir): """Test basic string replacement functionality.""" test_file = temp_test_dir / "test.txt" initial_content = "Hello world! This is a test." test_file.write_text(initial_content) - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "world", - "new_str": "universe" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "world", "new_str": "universe"} + ) + assert result["success"] is True assert test_file.read_text() == "Hello universe! This is a test." assert "Successfully replaced" in result["message"] + def test_file_not_found(): """Test handling of non-existent file.""" - result = file_str_replace.invoke({ - "filepath": "nonexistent.txt", - "old_str": "test", - "new_str": "replacement" - }) - + result = file_str_replace.invoke( + {"filepath": "nonexistent.txt", "old_str": "test", "new_str": "replacement"} + ) + assert result["success"] is False assert "File not found" in result["message"] + def test_string_not_found(temp_test_dir): """Test handling of string not present in file.""" test_file = temp_test_dir / "test.txt" test_file.write_text("Hello world!") - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "nonexistent", - "new_str": "replacement" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "nonexistent", "new_str": "replacement"} + ) + assert result["success"] is False assert "String not found" in result["message"] + def test_multiple_occurrences(temp_test_dir): """Test handling of multiple string occurrences.""" test_file = temp_test_dir / "test.txt" test_file.write_text("test test test") - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "test", - "new_str": "replacement" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "test", "new_str": "replacement"} + ) + assert result["success"] is False assert "appears" in result["message"] assert "must be unique" in result["message"] + def test_empty_strings(temp_test_dir): """Test handling of empty strings.""" test_file = temp_test_dir / "test.txt" test_file.write_text("Hello world!") - + # Test empty old string - result1 = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "", - "new_str": "replacement" - }) + result1 = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "", "new_str": "replacement"} + ) assert result1["success"] is False - + # Test empty new string - result2 = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "world", - "new_str": "" - }) + result2 = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "world", "new_str": ""} + ) assert result2["success"] is True assert test_file.read_text() == "Hello !" + def test_special_characters(temp_test_dir): """Test handling of special characters.""" test_file = temp_test_dir / "test.txt" initial_content = "Hello\nworld!\t\r\nSpecial chars: $@#%" test_file.write_text(initial_content) - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "Special chars: $@#%", - "new_str": "Replaced!" - }) - + + result = file_str_replace.invoke( + { + "filepath": str(test_file), + "old_str": "Special chars: $@#%", + "new_str": "Replaced!", + } + ) + assert result["success"] is True assert "Special chars: $@#%" not in test_file.read_text() assert "Replaced!" in test_file.read_text() -@patch('pathlib.Path.read_text') + +@patch("pathlib.Path.read_text") def test_io_error(mock_read_text, temp_test_dir): """Test handling of IO errors during read.""" # Create and write to file first test_file = temp_test_dir / "test.txt" test_file.write_text("some test content") - + # Then mock read_text to raise error mock_read_text.side_effect = IOError("Failed to read file") - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "test", - "new_str": "replacement" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "test", "new_str": "replacement"} + ) + assert result["success"] is False assert "Failed to read file" in result["message"] + def test_permission_error(temp_test_dir): """Test handling of permission errors.""" test_file = temp_test_dir / "readonly.txt" test_file.write_text("test content") os.chmod(test_file, 0o444) # Make file read-only - + try: - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "test", - "new_str": "replacement" - }) - + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "test", "new_str": "replacement"} + ) + assert result["success"] is False assert "Permission" in result["message"] or "Error" in result["message"] finally: os.chmod(test_file, 0o644) # Restore permissions for cleanup + def test_unicode_strings(temp_test_dir): """Test handling of Unicode strings.""" test_file = temp_test_dir / "unicode.txt" initial_content = "Hello δΈ–η•Œ! Unicode γƒ†γ‚Ήγƒˆ" test_file.write_text(initial_content) - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": "δΈ–η•Œ", - "new_str": "ワールド" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": "δΈ–η•Œ", "new_str": "ワールド"} + ) + assert result["success"] is True assert "δΈ–η•Œ" not in test_file.read_text() assert "ワールド" in test_file.read_text() + def test_long_string_truncation(temp_test_dir): """Test handling and truncation of very long strings.""" test_file = temp_test_dir / "test.txt" long_string = "x" * 100 test_file.write_text(f"prefix {long_string} suffix") - - result = file_str_replace.invoke({ - "filepath": str(test_file), - "old_str": long_string, - "new_str": "replaced" - }) - + + result = file_str_replace.invoke( + {"filepath": str(test_file), "old_str": long_string, "new_str": "replaced"} + ) + assert result["success"] is True assert test_file.read_text() == "prefix replaced suffix" diff --git a/tests/ra_aid/tools/test_fuzzy_find.py b/tests/ra_aid/tools/test_fuzzy_find.py index 53b9eee..b3c99d8 100644 --- a/tests/ra_aid/tools/test_fuzzy_find.py +++ b/tests/ra_aid/tools/test_fuzzy_find.py @@ -1,14 +1,15 @@ import pytest -from pytest import mark from git import Repo from git.exc import InvalidGitRepositoryError + from ra_aid.tools import fuzzy_find_project_files + @pytest.fixture def git_repo(tmp_path): """Create a temporary git repository with some test files""" repo = Repo.init(tmp_path) - + # Create some files (tmp_path / "main.py").write_text("print('hello')") (tmp_path / "test_main.py").write_text("def test_main(): pass") @@ -16,119 +17,117 @@ def git_repo(tmp_path): (tmp_path / "lib/utils.py").write_text("def util(): pass") (tmp_path / "lib/__pycache__").mkdir() (tmp_path / "lib/__pycache__/utils.cpython-39.pyc").write_text("cache") - + # Create some untracked files (tmp_path / "untracked.txt").write_text("untracked content") (tmp_path / "draft.py").write_text("# draft code") - + # Add and commit only some files repo.index.add(["main.py", "lib/utils.py"]) repo.index.commit("Initial commit") - + return tmp_path + def test_basic_fuzzy_search(git_repo): """Test basic fuzzy matching functionality""" - results = fuzzy_find_project_files.invoke({"search_term": "utils", "repo_path": str(git_repo)}) - + results = fuzzy_find_project_files.invoke( + {"search_term": "utils", "repo_path": str(git_repo)} + ) + assert len(results) >= 1 assert any("lib/utils.py" in match[0] for match in results) assert all(isinstance(match[1], int) for match in results) + def test_threshold_filtering(git_repo): """Test threshold parameter behavior""" # Should match with high threshold - results_high = fuzzy_find_project_files.invoke({ - "search_term": "main", - "threshold": 80, - "repo_path": str(git_repo) - }) + results_high = fuzzy_find_project_files.invoke( + {"search_term": "main", "threshold": 80, "repo_path": str(git_repo)} + ) assert len(results_high) >= 1 assert any("main.py" in match[0] for match in results_high) - + # Should not match with very high threshold - results_very_high = fuzzy_find_project_files.invoke({ - "search_term": "mian", - "threshold": 99, - "repo_path": str(git_repo) - }) + results_very_high = fuzzy_find_project_files.invoke( + {"search_term": "mian", "threshold": 99, "repo_path": str(git_repo)} + ) assert len(results_very_high) == 0 + def test_max_results_limit(git_repo): """Test max_results parameter""" max_results = 1 - results = fuzzy_find_project_files.invoke({ - "search_term": "py", - "max_results": max_results, - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "py", "max_results": max_results, "repo_path": str(git_repo)} + ) assert len(results) <= max_results + def test_include_paths_filter(git_repo): """Test include_paths filtering""" - results = fuzzy_find_project_files.invoke({ - "search_term": "py", - "include_paths": ["lib/*"], - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "py", "include_paths": ["lib/*"], "repo_path": str(git_repo)} + ) assert all("lib/" in match[0] for match in results) + def test_exclude_patterns_filter(git_repo): """Test exclude_patterns filtering""" - results = fuzzy_find_project_files.invoke({ - "search_term": "py", - "exclude_patterns": ["*test*"], - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + { + "search_term": "py", + "exclude_patterns": ["*test*"], + "repo_path": str(git_repo), + } + ) assert not any("test" in match[0] for match in results) + def test_invalid_threshold(): """Test error handling for invalid threshold""" with pytest.raises(ValueError): - fuzzy_find_project_files.invoke({ - "search_term": "test", - "threshold": 101 - }) + fuzzy_find_project_files.invoke({"search_term": "test", "threshold": 101}) + def test_non_git_repo(tmp_path): """Test error handling outside git repo""" with pytest.raises(InvalidGitRepositoryError): - fuzzy_find_project_files.invoke({ - "search_term": "test", - "repo_path": str(tmp_path) - }) + fuzzy_find_project_files.invoke( + {"search_term": "test", "repo_path": str(tmp_path)} + ) + def test_exact_match(git_repo): """Test exact matching returns 100% score""" - results = fuzzy_find_project_files.invoke({ - "search_term": "main.py", - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "main.py", "repo_path": str(git_repo)} + ) assert len(results) >= 1 assert any(match[1] == 100 for match in results) + def test_empty_search_term(git_repo): """Test behavior with empty search term""" - results = fuzzy_find_project_files.invoke({ - "search_term": "", - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "", "repo_path": str(git_repo)} + ) assert len(results) == 0 + def test_untracked_files(git_repo): """Test that untracked files are included in search results""" - results = fuzzy_find_project_files.invoke({ - "search_term": "untracked", - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "untracked", "repo_path": str(git_repo)} + ) assert len(results) >= 1 assert any("untracked.txt" in match[0] for match in results) + def test_no_matches(git_repo): """Test behavior when no files match the search term""" - results = fuzzy_find_project_files.invoke({ - "search_term": "nonexistentfile", - "threshold": 80, - "repo_path": str(git_repo) - }) + results = fuzzy_find_project_files.invoke( + {"search_term": "nonexistentfile", "threshold": 80, "repo_path": str(git_repo)} + ) assert len(results) == 0 diff --git a/tests/ra_aid/tools/test_handle_user_defined_test_cmd_execution.py b/tests/ra_aid/tools/test_handle_user_defined_test_cmd_execution.py index 1915f00..d6e4c85 100644 --- a/tests/ra_aid/tools/test_handle_user_defined_test_cmd_execution.py +++ b/tests/ra_aid/tools/test_handle_user_defined_test_cmd_execution.py @@ -1,41 +1,44 @@ """Tests for user-defined test command execution utilities.""" -import pytest -from unittest.mock import patch, Mock import subprocess +from unittest.mock import Mock, patch + +import pytest + from ra_aid.tools.handle_user_defined_test_cmd_execution import ( - TestState, TestCommandExecutor, - execute_test_command + TestState, + execute_test_command, ) + @pytest.fixture def test_state(): """Create a test state fixture.""" return TestState( - prompt="test prompt", - test_attempts=0, - auto_test=False, - should_break=False + prompt="test prompt", test_attempts=0, auto_test=False, should_break=False ) + @pytest.fixture def test_executor(): """Create a test executor fixture.""" config = {"test_cmd": "test", "max_test_cmd_retries": 3} return TestCommandExecutor(config, "test prompt") + def test_check_max_retries(test_executor): """Test max retries check.""" test_executor.state.test_attempts = 2 assert not test_executor.check_max_retries() - + test_executor.state.test_attempts = 3 assert test_executor.check_max_retries() - + test_executor.state.test_attempts = 4 assert test_executor.check_max_retries() + def test_handle_test_failure(test_executor): """Test handling of test failures.""" test_result = {"output": "error message"} @@ -44,79 +47,100 @@ def test_handle_test_failure(test_executor): assert not test_executor.state.should_break assert "error message" in test_executor.state.prompt + def test_run_test_command_success(test_executor): """Test successful test command execution.""" - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run: + with patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run: mock_run.return_value = {"success": True, "output": ""} test_executor.run_test_command("test", "original") assert test_executor.state.should_break assert test_executor.state.test_attempts == 1 + def test_run_test_command_failure(test_executor): """Test failed test command execution.""" - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run: + with patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run: mock_run.return_value = {"success": False, "output": "error"} test_executor.run_test_command("test", "original") assert not test_executor.state.should_break assert test_executor.state.test_attempts == 1 assert "error" in test_executor.state.prompt + def test_run_test_command_error(test_executor): """Test test command execution error.""" - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run: + with patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run: mock_run.side_effect = Exception("Generic error") test_executor.run_test_command("test", "original") assert test_executor.state.should_break assert test_executor.state.test_attempts == 1 + def test_run_test_command_timeout(test_executor): """Test test command timeout handling.""" - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run,\ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.logger.warning") as mock_logger: - + with ( + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.logger.warning" + ) as mock_logger, + ): # Create a TimeoutExpired exception timeout_exc = subprocess.TimeoutExpired(cmd="test", timeout=30) mock_run.side_effect = timeout_exc - + test_executor.run_test_command("test", "original") - + # Verify state updates assert not test_executor.state.should_break assert test_executor.state.test_attempts == 1 assert "timed out after 30 seconds" in test_executor.state.prompt - + # Verify logging mock_logger.assert_called_once() + def test_run_test_command_called_process_error(test_executor): """Test handling of CalledProcessError exception.""" - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command") as mock_run,\ - patch("ra_aid.tools.handle_user_defined_test_cmd_execution.logger.error") as mock_logger: - + with ( + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.run_shell_command" + ) as mock_run, + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.logger.error" + ) as mock_logger, + ): # Create a CalledProcessError exception process_error = subprocess.CalledProcessError( - returncode=1, - cmd="test", - output="Command failed output" + returncode=1, cmd="test", output="Command failed output" ) mock_run.side_effect = process_error - + test_executor.run_test_command("test", "original") - + # Verify state updates assert not test_executor.state.should_break assert test_executor.state.test_attempts == 1 assert "failed with exit code 1" in test_executor.state.prompt - + # Verify logging mock_logger.assert_called_once() + def test_handle_user_response_no(test_executor): """Test handling of 'no' response.""" test_executor.handle_user_response("n", "test", "original") assert test_executor.state.should_break assert not test_executor.state.auto_test + def test_handle_user_response_auto(test_executor): """Test handling of 'auto' response.""" with patch.object(test_executor, "run_test_command") as mock_run: @@ -124,6 +148,7 @@ def test_handle_user_response_auto(test_executor): assert test_executor.state.auto_test mock_run.assert_called_once_with("test", "original") + def test_handle_user_response_yes(test_executor): """Test handling of 'yes' response.""" with patch.object(test_executor, "run_test_command") as mock_run: @@ -131,54 +156,66 @@ def test_handle_user_response_yes(test_executor): assert not test_executor.state.auto_test mock_run.assert_called_once_with("test", "original") + def test_execute_no_cmd(): """Test execution with no test command.""" executor = TestCommandExecutor({}, "prompt") result = executor.execute() assert result == (True, "prompt", False, 0) + def test_execute_manual(): """Test manual test execution.""" config = {"test_cmd": "test"} executor = TestCommandExecutor(config, "prompt") - + def mock_handle_response(response, cmd, prompt): # Simulate the behavior of handle_user_response and run_test_command executor.state.should_break = True executor.state.test_attempts = 1 executor.state.prompt = "new prompt" - - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.ask_human") as mock_ask, \ - patch.object(executor, "handle_user_response", side_effect=mock_handle_response) as mock_handle: + + with ( + patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.ask_human" + ) as mock_ask, + patch.object( + executor, "handle_user_response", side_effect=mock_handle_response + ) as mock_handle, + ): mock_ask.invoke.return_value = "y" - + result = executor.execute() mock_handle.assert_called_once_with("y", "test", "prompt") assert result == (True, "new prompt", False, 1) + def test_execute_auto(): """Test auto test execution.""" config = {"test_cmd": "test", "max_test_cmd_retries": 3} executor = TestCommandExecutor(config, "prompt", auto_test=True) - + # Set up state before creating mock executor.state.test_attempts = 1 executor.state.should_break = True - + with patch.object(executor, "run_test_command") as mock_run: result = executor.execute() assert result == (True, "prompt", True, 1) mock_run.assert_called_once_with("test", "prompt") + def test_execute_test_command_function(): """Test the execute_test_command function.""" config = {"test_cmd": "test"} - with patch("ra_aid.tools.handle_user_defined_test_cmd_execution.TestCommandExecutor") as mock_executor_class: + with patch( + "ra_aid.tools.handle_user_defined_test_cmd_execution.TestCommandExecutor" + ) as mock_executor_class: mock_executor = Mock() mock_executor.execute.return_value = (True, "new prompt", True, 1) mock_executor_class.return_value = mock_executor - + result = execute_test_command(config, "prompt", auto_test=True) assert result == (True, "new prompt", True, 1) mock_executor_class.assert_called_once_with(config, "prompt", 0, True) - mock_executor.execute.assert_called_once() \ No newline at end of file + mock_executor.execute.assert_called_once() diff --git a/tests/ra_aid/tools/test_list_directory.py b/tests/ra_aid/tools/test_list_directory.py index 3d182a0..91a8d66 100644 --- a/tests/ra_aid/tools/test_list_directory.py +++ b/tests/ra_aid/tools/test_list_directory.py @@ -1,126 +1,137 @@ -import os -import pytest import tempfile from datetime import datetime from pathlib import Path + +import pytest + from ra_aid.tools import list_directory_tree from ra_aid.tools.list_directory import load_gitignore_patterns, should_ignore EXPECTED_YEAR = str(datetime.now().year) + @pytest.fixture def temp_dir(): """Create a temporary directory for testing""" with tempfile.TemporaryDirectory() as tmpdir: yield Path(tmpdir) + def create_test_directory_structure(path: Path): """Create a test directory structure""" # Create files (path / "file1.txt").write_text("content1") (path / "file2.py").write_text("content2") (path / ".hidden").write_text("hidden") - + # Create subdirectories subdir1 = path / "subdir1" subdir1.mkdir() (subdir1 / "subfile1.txt").write_text("subcontent1") (subdir1 / "subfile2.py").write_text("subcontent2") - + subdir2 = path / "subdir2" subdir2.mkdir() (subdir2 / ".git").mkdir() (subdir2 / "__pycache__").mkdir() + def test_list_directory_basic(temp_dir): """Test basic directory listing functionality""" create_test_directory_structure(temp_dir) - - result = list_directory_tree.invoke({ - "path": str(temp_dir), - "max_depth": 2, - "follow_links": False - }) - + + result = list_directory_tree.invoke( + {"path": str(temp_dir), "max_depth": 2, "follow_links": False} + ) + # Check basic structure assert isinstance(result, str) assert "file1.txt" in result assert "file2.py" in result assert "subdir1" in result assert "subdir2" in result - + # Hidden files should be excluded by default assert ".hidden" not in result assert ".git" not in result assert "__pycache__" not in result - + # File details should not be present by default assert "bytes" not in result.lower() assert "2024-" not in result + def test_list_directory_with_details(temp_dir): """Test directory listing with file details""" create_test_directory_structure(temp_dir) - - result = list_directory_tree.invoke({ - "path": str(temp_dir), - "max_depth": 2, - "show_size": True, - "show_modified": True - }) - + + result = list_directory_tree.invoke( + { + "path": str(temp_dir), + "max_depth": 2, + "show_size": True, + "show_modified": True, + } + ) + # File details should be present assert "bytes" in result.lower() or "kb" in result.lower() or "b" in result.lower() assert f"{EXPECTED_YEAR}-" in result + def test_list_directory_depth_limit(temp_dir): """Test max_depth parameter""" create_test_directory_structure(temp_dir) - + # Test with depth 1 (default) - result = list_directory_tree.invoke({ - "path": str(temp_dir) # Use defaults - }) - + result = list_directory_tree.invoke( + { + "path": str(temp_dir) # Use defaults + } + ) + assert isinstance(result, str) assert "subdir1" in result # Directory name should be visible assert "subfile1.txt" not in result # But not its contents assert "subfile2.py" not in result + def test_list_directory_ignore_patterns(temp_dir): """Test exclude patterns""" create_test_directory_structure(temp_dir) - - result = list_directory_tree.invoke({ - "path": str(temp_dir), - "max_depth": 2, - "exclude_patterns": ["*.py"] - }) - + + result = list_directory_tree.invoke( + {"path": str(temp_dir), "max_depth": 2, "exclude_patterns": ["*.py"]} + ) + assert isinstance(result, str) assert "file1.txt" in result assert "file2.py" not in result assert "subfile2.py" not in result + def test_gitignore_patterns(): """Test gitignore pattern loading and matching""" with tempfile.TemporaryDirectory() as tmpdir: path = Path(tmpdir) - + # Create a .gitignore file (path / ".gitignore").write_text("*.log\n*.tmp\n") - + spec = load_gitignore_patterns(path) - + assert should_ignore("test.log", spec) is True assert should_ignore("test.tmp", spec) is True assert should_ignore("test.txt", spec) is False assert should_ignore("dir/test.log", spec) is True + def test_invalid_path(): """Test error handling for invalid paths""" with pytest.raises(ValueError, match="Path does not exist"): list_directory_tree.invoke({"path": "/nonexistent/path"}) - + with pytest.raises(ValueError, match="Path is not a directory"): - list_directory_tree.invoke({"path": __file__}) # Try to list the test file itself + list_directory_tree.invoke( + {"path": __file__} + ) # Try to list the test file itself diff --git a/tests/ra_aid/tools/test_memory.py b/tests/ra_aid/tools/test_memory.py index 80cc430..4689f0a 100644 --- a/tests/ra_aid/tools/test_memory.py +++ b/tests/ra_aid/tools/test_memory.py @@ -1,105 +1,113 @@ import pytest + from ra_aid.tools.memory import ( _global_memory, - get_memory_value, - emit_key_facts, delete_key_facts, - emit_key_snippets, delete_key_snippets, - emit_related_files, - get_related_files, - deregister_related_files, - emit_task, delete_tasks, - swap_task_order, + deregister_related_files, + emit_key_facts, + emit_key_snippets, + emit_related_files, + emit_task, + get_memory_value, + get_related_files, + get_work_log, log_work_event, reset_work_log, - get_work_log + swap_task_order, ) + @pytest.fixture def reset_memory(): """Reset global memory before each test""" - _global_memory['key_facts'] = {} - _global_memory['key_fact_id_counter'] = 0 - _global_memory['key_snippets'] = {} - _global_memory['key_snippet_id_counter'] = 0 - _global_memory['research_notes'] = [] - _global_memory['plans'] = [] - _global_memory['tasks'] = {} - _global_memory['task_id_counter'] = 0 - _global_memory['related_files'] = {} - _global_memory['related_file_id_counter'] = 0 - _global_memory['work_log'] = [] + _global_memory["key_facts"] = {} + _global_memory["key_fact_id_counter"] = 0 + _global_memory["key_snippets"] = {} + _global_memory["key_snippet_id_counter"] = 0 + _global_memory["research_notes"] = [] + _global_memory["plans"] = [] + _global_memory["tasks"] = {} + _global_memory["task_id_counter"] = 0 + _global_memory["related_files"] = {} + _global_memory["related_file_id_counter"] = 0 + _global_memory["work_log"] = [] yield # Clean up after test - _global_memory['key_facts'] = {} - _global_memory['key_fact_id_counter'] = 0 - _global_memory['key_snippets'] = {} - _global_memory['key_snippet_id_counter'] = 0 - _global_memory['research_notes'] = [] - _global_memory['plans'] = [] - _global_memory['tasks'] = {} - _global_memory['task_id_counter'] = 0 - _global_memory['related_files'] = {} - _global_memory['related_file_id_counter'] = 0 - _global_memory['work_log'] = [] + _global_memory["key_facts"] = {} + _global_memory["key_fact_id_counter"] = 0 + _global_memory["key_snippets"] = {} + _global_memory["key_snippet_id_counter"] = 0 + _global_memory["research_notes"] = [] + _global_memory["plans"] = [] + _global_memory["tasks"] = {} + _global_memory["task_id_counter"] = 0 + _global_memory["related_files"] = {} + _global_memory["related_file_id_counter"] = 0 + _global_memory["work_log"] = [] + def test_emit_key_facts_single_fact(reset_memory): """Test emitting a single key fact using emit_key_facts""" # Test with single fact result = emit_key_facts.invoke({"facts": ["First fact"]}) assert result == "Facts stored." - assert _global_memory['key_facts'][0] == "First fact" - assert _global_memory['key_fact_id_counter'] == 1 + assert _global_memory["key_facts"][0] == "First fact" + assert _global_memory["key_fact_id_counter"] == 1 + def test_delete_key_facts_single_fact(reset_memory): """Test deleting a single key fact using delete_key_facts""" # Add a fact emit_key_facts.invoke({"facts": ["Test fact"]}) - + # Delete the fact result = delete_key_facts.invoke({"fact_ids": [0]}) assert result == "Facts deleted." - assert 0 not in _global_memory['key_facts'] + assert 0 not in _global_memory["key_facts"] + def test_delete_key_facts_invalid(reset_memory): """Test deleting non-existent facts returns empty list""" # Try to delete non-existent fact result = delete_key_facts.invoke({"fact_ids": [999]}) assert result == "Facts deleted." - + # Add and delete a fact, then try to delete it again emit_key_facts.invoke({"facts": ["Test fact"]}) delete_key_facts.invoke({"fact_ids": [0]}) result = delete_key_facts.invoke({"fact_ids": [0]}) assert result == "Facts deleted." + def test_get_memory_value_key_facts(reset_memory): """Test get_memory_value with key facts dictionary""" # Empty key facts should return empty string - assert get_memory_value('key_facts') == "" - + assert get_memory_value("key_facts") == "" + # Add some facts emit_key_facts.invoke({"facts": ["First fact", "Second fact"]}) - + # Should return markdown formatted list expected = "## πŸ”‘ Key Fact #0\n\nFirst fact\n\n## πŸ”‘ Key Fact #1\n\nSecond fact" - assert get_memory_value('key_facts') == expected + assert get_memory_value("key_facts") == expected + def test_get_memory_value_other_types(reset_memory): """Test get_memory_value remains compatible with other memory types""" # Add some research notes - _global_memory['research_notes'].append("Note 1") - _global_memory['research_notes'].append("Note 2") - - assert get_memory_value('research_notes') == "Note 1\nNote 2" - + _global_memory["research_notes"].append("Note 1") + _global_memory["research_notes"].append("Note 2") + + assert get_memory_value("research_notes") == "Note 1\nNote 2" + # Test with empty list - assert get_memory_value('plans') == "" - + assert get_memory_value("plans") == "" + # Test with non-existent key - assert get_memory_value('nonexistent') == "" + assert get_memory_value("nonexistent") == "" + def test_log_work_event(reset_memory): """Test logging work events with timestamps""" @@ -107,85 +115,91 @@ def test_log_work_event(reset_memory): log_work_event("Started task") log_work_event("Made progress") log_work_event("Completed task") - + # Verify events are stored - assert len(_global_memory['work_log']) == 3 - + assert len(_global_memory["work_log"]) == 3 + # Check event structure - event = _global_memory['work_log'][0] - assert isinstance(event['timestamp'], str) - assert event['event'] == "Started task" - + event = _global_memory["work_log"][0] + assert isinstance(event["timestamp"], str) + assert event["event"] == "Started task" + # Verify order - assert _global_memory['work_log'][1]['event'] == "Made progress" - assert _global_memory['work_log'][2]['event'] == "Completed task" + assert _global_memory["work_log"][1]["event"] == "Made progress" + assert _global_memory["work_log"][2]["event"] == "Completed task" + def test_get_work_log(reset_memory): """Test work log formatting with heading-based markdown""" # Test empty log assert get_work_log() == "No work log entries" - + # Add some events log_work_event("First event") log_work_event("Second event") - + # Get formatted log log = get_work_log() - + assert "First event" in log assert "Second event" in log + def test_reset_work_log(reset_memory): """Test resetting the work log""" # Add some events log_work_event("Test event") - assert len(_global_memory['work_log']) == 1 - + assert len(_global_memory["work_log"]) == 1 + # Reset log reset_work_log() - + # Verify log is empty - assert len(_global_memory['work_log']) == 0 - assert get_memory_value('work_log') == "" + assert len(_global_memory["work_log"]) == 0 + assert get_memory_value("work_log") == "" + def test_empty_work_log(reset_memory): """Test empty work log behavior""" # Fresh work log should return empty string - assert get_memory_value('work_log') == "" + assert get_memory_value("work_log") == "" + def test_emit_key_facts(reset_memory): """Test emitting multiple key facts at once""" # Test emitting multiple facts facts = ["First fact", "Second fact", "Third fact"] result = emit_key_facts.invoke({"facts": facts}) - + # Verify return message assert result == "Facts stored." - + # Verify facts stored in memory with correct IDs - assert _global_memory['key_facts'][0] == "First fact" - assert _global_memory['key_facts'][1] == "Second fact" - assert _global_memory['key_facts'][2] == "Third fact" - + assert _global_memory["key_facts"][0] == "First fact" + assert _global_memory["key_facts"][1] == "Second fact" + assert _global_memory["key_facts"][2] == "Third fact" + # Verify counter incremented correctly - assert _global_memory['key_fact_id_counter'] == 3 + assert _global_memory["key_fact_id_counter"] == 3 + def test_delete_key_facts(reset_memory): """Test deleting multiple key facts""" # Add some test facts emit_key_facts.invoke({"facts": ["First fact", "Second fact", "Third fact"]}) - + # Test deleting mix of existing and non-existing IDs result = delete_key_facts.invoke({"fact_ids": [0, 1, 999]}) - + # Verify success message assert result == "Facts deleted." - + # Verify correct facts removed from memory - assert 0 not in _global_memory['key_facts'] - assert 1 not in _global_memory['key_facts'] - assert 2 in _global_memory['key_facts'] # ID 2 should remain - assert _global_memory['key_facts'][2] == "Third fact" + assert 0 not in _global_memory["key_facts"] + assert 1 not in _global_memory["key_facts"] + assert 2 in _global_memory["key_facts"] # ID 2 should remain + assert _global_memory["key_facts"][2] == "Third fact" + def test_emit_key_snippets(reset_memory): """Test emitting multiple code snippets at once""" @@ -195,28 +209,29 @@ def test_emit_key_snippets(reset_memory): "filepath": "test.py", "line_number": 10, "snippet": "def test():\n pass", - "description": "Test function" + "description": "Test function", }, { "filepath": "main.py", "line_number": 20, "snippet": "print('hello')", - "description": None - } + "description": None, + }, ] - + # Emit snippets result = emit_key_snippets.invoke({"snippets": snippets}) - + # Verify return message assert result == "Snippets stored." - + # Verify snippets stored correctly - assert _global_memory['key_snippets'][0] == snippets[0] - assert _global_memory['key_snippets'][1] == snippets[1] - + assert _global_memory["key_snippets"][0] == snippets[0] + assert _global_memory["key_snippets"][1] == snippets[1] + # Verify counter incremented correctly - assert _global_memory['key_snippet_id_counter'] == 2 + assert _global_memory["key_snippet_id_counter"] == 2 + def test_delete_key_snippets(reset_memory): """Test deleting multiple code snippets""" @@ -226,34 +241,35 @@ def test_delete_key_snippets(reset_memory): "filepath": "test1.py", "line_number": 1, "snippet": "code1", - "description": None + "description": None, }, { "filepath": "test2.py", "line_number": 2, "snippet": "code2", - "description": None + "description": None, }, { "filepath": "test3.py", "line_number": 3, "snippet": "code3", - "description": None - } + "description": None, + }, ] emit_key_snippets.invoke({"snippets": snippets}) - + # Test deleting mix of valid and invalid IDs result = delete_key_snippets.invoke({"snippet_ids": [0, 1, 999]}) - + # Verify success message assert result == "Snippets deleted." - + # Verify correct snippets removed - assert 0 not in _global_memory['key_snippets'] - assert 1 not in _global_memory['key_snippets'] - assert 2 in _global_memory['key_snippets'] - assert _global_memory['key_snippets'][2]['filepath'] == "test3.py" + assert 0 not in _global_memory["key_snippets"] + assert 1 not in _global_memory["key_snippets"] + assert 2 in _global_memory["key_snippets"] + assert _global_memory["key_snippets"][2]["filepath"] == "test3.py" + def test_delete_key_snippets_empty(reset_memory): """Test deleting snippets with empty ID list""" @@ -262,16 +278,17 @@ def test_delete_key_snippets_empty(reset_memory): "filepath": "test.py", "line_number": 1, "snippet": "code", - "description": None + "description": None, } emit_key_snippets.invoke({"snippets": [snippet]}) - + # Test with empty list result = delete_key_snippets.invoke({"snippet_ids": []}) assert result == "Snippets deleted." - + # Verify snippet still exists - assert 0 in _global_memory['key_snippets'] + assert 0 in _global_memory["key_snippets"] + def test_emit_related_files_basic(reset_memory, tmp_path): """Test basic adding of files with ID tracking""" @@ -282,24 +299,26 @@ def test_emit_related_files_basic(reset_memory, tmp_path): main_file.write_text("# Main file") utils_file = tmp_path / "utils.py" utils_file.write_text("# Utils file") - + # Test adding single file result = emit_related_files.invoke({"files": [str(test_file)]}) assert result == f"File ID #0: {test_file}" - assert _global_memory['related_files'][0] == str(test_file) - + assert _global_memory["related_files"][0] == str(test_file) + # Test adding multiple files result = emit_related_files.invoke({"files": [str(main_file), str(utils_file)]}) assert result == f"File ID #1: {main_file}\nFile ID #2: {utils_file}" # Verify both files exist in related_files - values = list(_global_memory['related_files'].values()) + values = list(_global_memory["related_files"].values()) assert str(main_file) in values assert str(utils_file) in values + def test_get_related_files_empty(reset_memory): """Test getting related files when none added""" assert get_related_files() == [] + def test_emit_related_files_duplicates(reset_memory, tmp_path): """Test that duplicate files return existing IDs with proper formatting""" # Create test files @@ -309,21 +328,22 @@ def test_emit_related_files_duplicates(reset_memory, tmp_path): main_file.write_text("# Main file") new_file = tmp_path / "new.py" new_file.write_text("# New file") - + # Add initial files result = emit_related_files.invoke({"files": [str(test_file), str(main_file)]}) assert result == f"File ID #0: {test_file}\nFile ID #1: {main_file}" - first_id = 0 # ID of test.py - + _first_id = 0 # ID of test.py + # Try adding duplicates result = emit_related_files.invoke({"files": [str(test_file)]}) assert result == f"File ID #0: {test_file}" # Should return same ID - assert len(_global_memory['related_files']) == 2 # Count should not increase - + assert len(_global_memory["related_files"]) == 2 # Count should not increase + # Try mix of new and duplicate files result = emit_related_files.invoke({"files": [str(test_file), str(new_file)]}) assert result == f"File ID #0: {test_file}\nFile ID #2: {new_file}" - assert len(_global_memory['related_files']) == 3 + assert len(_global_memory["related_files"]) == 3 + def test_related_files_id_tracking(reset_memory, tmp_path): """Test ID assignment and counter functionality for related files""" @@ -332,20 +352,21 @@ def test_related_files_id_tracking(reset_memory, tmp_path): file1.write_text("# File 1") file2 = tmp_path / "file2.py" file2.write_text("# File 2") - + # Add first file result = emit_related_files.invoke({"files": [str(file1)]}) assert result == f"File ID #0: {file1}" - assert _global_memory['related_file_id_counter'] == 1 - + assert _global_memory["related_file_id_counter"] == 1 + # Add second file result = emit_related_files.invoke({"files": [str(file2)]}) assert result == f"File ID #1: {file2}" - assert _global_memory['related_file_id_counter'] == 2 - + assert _global_memory["related_file_id_counter"] == 2 + # Verify all files stored correctly - assert _global_memory['related_files'][0] == str(file1) - assert _global_memory['related_files'][1] == str(file2) + assert _global_memory["related_files"][0] == str(file1) + assert _global_memory["related_files"][1] == str(file2) + def test_deregister_related_files(reset_memory, tmp_path): """Test deleting related files""" @@ -356,41 +377,43 @@ def test_deregister_related_files(reset_memory, tmp_path): file2.write_text("# File 2") file3 = tmp_path / "file3.py" file3.write_text("# File 3") - + # Add test files emit_related_files.invoke({"files": [str(file1), str(file2), str(file3)]}) - + # Delete middle file result = deregister_related_files.invoke({"file_ids": [1]}) assert result == "File references removed." - assert 1 not in _global_memory['related_files'] - assert len(_global_memory['related_files']) == 2 - + assert 1 not in _global_memory["related_files"] + assert len(_global_memory["related_files"]) == 2 + # Delete multiple files including non-existent ID result = deregister_related_files.invoke({"file_ids": [0, 2, 999]}) assert result == "File references removed." - assert len(_global_memory['related_files']) == 0 - + assert len(_global_memory["related_files"]) == 0 + # Counter should remain unchanged after deletions - assert _global_memory['related_file_id_counter'] == 3 + assert _global_memory["related_file_id_counter"] == 3 + def test_related_files_duplicates(reset_memory, tmp_path): """Test duplicate file handling returns same ID""" # Create test file test_file = tmp_path / "test.py" test_file.write_text("# Test file") - + # Add initial file result1 = emit_related_files.invoke({"files": [str(test_file)]}) assert result1 == f"File ID #0: {test_file}" - + # Add same file again result2 = emit_related_files.invoke({"files": [str(test_file)]}) assert result2 == f"File ID #0: {test_file}" - + # Verify only one entry exists - assert len(_global_memory['related_files']) == 1 - assert _global_memory['related_file_id_counter'] == 1 + assert len(_global_memory["related_files"]) == 1 + assert _global_memory["related_file_id_counter"] == 1 + def test_emit_related_files_with_directory(reset_memory, tmp_path): """Test that directories and non-existent paths are rejected while valid files are added""" @@ -400,24 +423,25 @@ def test_emit_related_files_with_directory(reset_memory, tmp_path): test_file = tmp_path / "test_file.txt" test_file.write_text("test content") nonexistent = tmp_path / "does_not_exist.txt" - + # Try to emit directory, nonexistent path, and valid file - result = emit_related_files.invoke({ - "files": [str(test_dir), str(nonexistent), str(test_file)] - }) - + result = emit_related_files.invoke( + {"files": [str(test_dir), str(nonexistent), str(test_file)]} + ) + # Verify specific error messages for directory and nonexistent path assert f"Error: Path '{test_dir}' is a directory, not a file" in result assert f"Error: Path '{nonexistent}' does not exist" in result - + # Verify directory and nonexistent not added but valid file was - assert len(_global_memory['related_files']) == 1 - values = list(_global_memory['related_files'].values()) + assert len(_global_memory["related_files"]) == 1 + values = list(_global_memory["related_files"].values()) assert str(test_file) in values assert str(test_dir) not in values assert str(nonexistent) not in values assert str(nonexistent) not in values + def test_related_files_formatting(reset_memory, tmp_path): """Test related files output string formatting""" # Create test files @@ -425,19 +449,20 @@ def test_related_files_formatting(reset_memory, tmp_path): file1.write_text("# File 1") file2 = tmp_path / "file2.py" file2.write_text("# File 2") - + # Add some files emit_related_files.invoke({"files": [str(file1), str(file2)]}) - + # Get formatted output - output = get_memory_value('related_files') + output = get_memory_value("related_files") # Expect just the IDs on separate lines expected = "0\n1" assert output == expected - + # Test empty case - _global_memory['related_files'] = {} - assert get_memory_value('related_files') == "" + _global_memory["related_files"] = {} + assert get_memory_value("related_files") == "" + def test_key_snippets_integration(reset_memory, tmp_path): """Integration test for key snippets functionality""" @@ -448,59 +473,59 @@ def test_key_snippets_integration(reset_memory, tmp_path): file2.write_text("def func2():\n return True") file3 = tmp_path / "file3.py" file3.write_text("class TestClass:\n pass") - + # Initial snippets to add snippets = [ { "filepath": str(file1), "line_number": 10, "snippet": "def func1():\n pass", - "description": "First function" + "description": "First function", }, { "filepath": str(file2), "line_number": 20, "snippet": "def func2():\n return True", - "description": "Second function" + "description": "Second function", }, { "filepath": str(file3), "line_number": 30, "snippet": "class TestClass:\n pass", - "description": "Test class" - } + "description": "Test class", + }, ] - + # Add all snippets result = emit_key_snippets.invoke({"snippets": snippets}) assert result == "Snippets stored." - assert _global_memory['key_snippet_id_counter'] == 3 + assert _global_memory["key_snippet_id_counter"] == 3 # Verify related files were tracked with IDs - assert len(_global_memory['related_files']) == 3 + assert len(_global_memory["related_files"]) == 3 # Check files are stored with proper IDs - file_values = _global_memory['related_files'].values() - assert str(file1) in file_values + file_values = _global_memory["related_files"].values() + assert str(file1) in file_values assert str(file2) in file_values assert str(file3) in file_values - + # Verify all snippets were stored correctly - assert len(_global_memory['key_snippets']) == 3 - assert _global_memory['key_snippets'][0] == snippets[0] - assert _global_memory['key_snippets'][1] == snippets[1] - assert _global_memory['key_snippets'][2] == snippets[2] - + assert len(_global_memory["key_snippets"]) == 3 + assert _global_memory["key_snippets"][0] == snippets[0] + assert _global_memory["key_snippets"][1] == snippets[1] + assert _global_memory["key_snippets"][2] == snippets[2] + # Delete some but not all snippets (0 and 2) result = delete_key_snippets.invoke({"snippet_ids": [0, 2]}) assert result == "Snippets deleted." - + # Verify remaining snippet is intact - assert len(_global_memory['key_snippets']) == 1 - assert 1 in _global_memory['key_snippets'] - assert _global_memory['key_snippets'][1] == snippets[1] - + assert len(_global_memory["key_snippets"]) == 1 + assert 1 in _global_memory["key_snippets"] + assert _global_memory["key_snippets"][1] == snippets[1] + # Counter should remain unchanged after deletions - assert _global_memory['key_snippet_id_counter'] == 3 - + assert _global_memory["key_snippet_id_counter"] == 3 + # Add new snippet to verify counter continues correctly file4 = tmp_path / "file4.py" file4.write_text("def func4():\n return False") @@ -508,47 +533,49 @@ def test_key_snippets_integration(reset_memory, tmp_path): "filepath": str(file4), "line_number": 40, "snippet": "def func4():\n return False", - "description": "Fourth function" + "description": "Fourth function", } result = emit_key_snippets.invoke({"snippets": [new_snippet]}) assert result == "Snippets stored." - assert _global_memory['key_snippet_id_counter'] == 4 + assert _global_memory["key_snippet_id_counter"] == 4 # Verify new file was added to related files - file_values = _global_memory['related_files'].values() + file_values = _global_memory["related_files"].values() assert str(file4) in file_values - assert len(_global_memory['related_files']) == 4 - + assert len(_global_memory["related_files"]) == 4 + # Delete remaining snippets result = delete_key_snippets.invoke({"snippet_ids": [1, 3]}) assert result == "Snippets deleted." - + # Verify all snippets are gone - assert len(_global_memory['key_snippets']) == 0 - + assert len(_global_memory["key_snippets"]) == 0 + # Counter should still maintain its value - assert _global_memory['key_snippet_id_counter'] == 4 + assert _global_memory["key_snippet_id_counter"] == 4 + def test_emit_task_with_id(reset_memory): """Test emitting tasks with ID tracking""" # Test adding a single task task = "Implement new feature" result = emit_task.invoke({"task": task}) - + # Verify return message includes task ID assert result == "Task #0 stored." - + # Verify task stored correctly with ID - assert _global_memory['tasks'][0] == task - + assert _global_memory["tasks"][0] == task + # Verify counter incremented - assert _global_memory['task_id_counter'] == 1 - + assert _global_memory["task_id_counter"] == 1 + # Add another task to verify counter continues correctly task2 = "Fix bug" result = emit_task.invoke({"task": task2}) assert result == "Task #1 stored." - assert _global_memory['tasks'][1] == task2 - assert _global_memory['task_id_counter'] == 2 + assert _global_memory["tasks"][1] == task2 + assert _global_memory["task_id_counter"] == 2 + def test_delete_tasks(reset_memory): """Test deleting tasks""" @@ -556,24 +583,25 @@ def test_delete_tasks(reset_memory): tasks = ["Task 1", "Task 2", "Task 3"] for task in tasks: emit_task.invoke({"task": task}) - + # Test deleting single task result = delete_tasks.invoke({"task_ids": [1]}) assert result == "Tasks deleted." - assert 1 not in _global_memory['tasks'] - assert len(_global_memory['tasks']) == 2 - + assert 1 not in _global_memory["tasks"] + assert len(_global_memory["tasks"]) == 2 + # Test deleting multiple tasks including non-existent ID result = delete_tasks.invoke({"task_ids": [0, 2, 999]}) assert result == "Tasks deleted." - assert len(_global_memory['tasks']) == 0 - + assert len(_global_memory["tasks"]) == 0 + # Test deleting from empty tasks dict result = delete_tasks.invoke({"task_ids": [0]}) assert result == "Tasks deleted." - + # Counter should remain unchanged after deletions - assert _global_memory['task_id_counter'] == 3 + assert _global_memory["task_id_counter"] == 3 + def test_swap_task_order_valid_ids(reset_memory): """Test basic task swapping functionality""" @@ -581,63 +609,67 @@ def test_swap_task_order_valid_ids(reset_memory): tasks = ["Task 1", "Task 2", "Task 3"] for task in tasks: emit_task.invoke({"task": task}) - + # Swap tasks 0 and 2 result = swap_task_order.invoke({"id1": 0, "id2": 2}) assert result == "Tasks swapped." - + # Verify tasks were swapped - assert _global_memory['tasks'][0] == "Task 3" - assert _global_memory['tasks'][2] == "Task 1" - assert _global_memory['tasks'][1] == "Task 2" # Unchanged + assert _global_memory["tasks"][0] == "Task 3" + assert _global_memory["tasks"][2] == "Task 1" + assert _global_memory["tasks"][1] == "Task 2" # Unchanged + def test_swap_task_order_invalid_ids(reset_memory): """Test error handling for invalid task IDs""" # Add a test task emit_task.invoke({"task": "Task 1"}) - + # Try to swap with non-existent ID result = swap_task_order.invoke({"id1": 0, "id2": 999}) assert result == "Invalid task ID(s)" - + # Verify original task unchanged - assert _global_memory['tasks'][0] == "Task 1" + assert _global_memory["tasks"][0] == "Task 1" + def test_swap_task_order_same_id(reset_memory): """Test handling of attempt to swap a task with itself""" # Add test task emit_task.invoke({"task": "Task 1"}) - + # Try to swap task with itself result = swap_task_order.invoke({"id1": 0, "id2": 0}) assert result == "Cannot swap task with itself" - + # Verify task unchanged - assert _global_memory['tasks'][0] == "Task 1" + assert _global_memory["tasks"][0] == "Task 1" + def test_swap_task_order_empty_tasks(reset_memory): """Test swapping behavior with empty tasks dictionary""" result = swap_task_order.invoke({"id1": 0, "id2": 1}) assert result == "Invalid task ID(s)" + def test_swap_task_order_after_delete(reset_memory): """Test swapping after deleting a task""" # Add test tasks tasks = ["Task 1", "Task 2", "Task 3"] for task in tasks: emit_task.invoke({"task": task}) - + # Delete middle task delete_tasks.invoke({"task_ids": [1]}) - + # Try to swap with deleted task result = swap_task_order.invoke({"id1": 0, "id2": 1}) assert result == "Invalid task ID(s)" - + # Try to swap remaining valid tasks result = swap_task_order.invoke({"id1": 0, "id2": 2}) assert result == "Tasks swapped." - + # Verify swap worked - assert _global_memory['tasks'][0] == "Task 3" - assert _global_memory['tasks'][2] == "Task 1" + assert _global_memory["tasks"][0] == "Task 3" + assert _global_memory["tasks"][2] == "Task 1" diff --git a/tests/ra_aid/tools/test_read_file.py b/tests/ra_aid/tools/test_read_file.py index 9c6ade8..c32b7fd 100644 --- a/tests/ra_aid/tools/test_read_file.py +++ b/tests/ra_aid/tools/test_read_file.py @@ -1,8 +1,8 @@ import pytest -from pytest import mark -from langchain.schema.runnable import Runnable + from ra_aid.tools import read_file_tool + def test_basic_file_reading(tmp_path): """Test basic file reading functionality""" # Create a test file @@ -15,8 +15,9 @@ def test_basic_file_reading(tmp_path): # Verify return format and content assert isinstance(result, dict) - assert 'content' in result - assert result['content'] == test_content + assert "content" in result + assert result["content"] == test_content + def test_no_truncation(tmp_path): """Test that files under max_lines are not truncated""" @@ -31,8 +32,9 @@ def test_no_truncation(tmp_path): # Verify no truncation occurred assert isinstance(result, dict) - assert '[lines of output truncated]' not in result['content'] - assert len(result['content'].splitlines()) == line_count + assert "[lines of output truncated]" not in result["content"] + assert len(result["content"].splitlines()) == line_count + def test_with_truncation(tmp_path): """Test that files over max_lines are properly truncated""" @@ -47,14 +49,18 @@ def test_with_truncation(tmp_path): # Verify truncation occurred correctly assert isinstance(result, dict) - assert '[1000 lines of output truncated]' in result['content'] - assert len(result['content'].splitlines()) == 5001 # 5000 content lines + 1 truncation message + assert "[1000 lines of output truncated]" in result["content"] + assert ( + len(result["content"].splitlines()) == 5001 + ) # 5000 content lines + 1 truncation message + def test_nonexistent_file(): """Test error handling for non-existent files""" with pytest.raises(FileNotFoundError): read_file_tool.invoke({"filepath": "/nonexistent/file.txt"}) + def test_empty_file(tmp_path): """Test reading an empty file""" # Create an empty test file @@ -66,5 +72,5 @@ def test_empty_file(tmp_path): # Verify return format and empty content assert isinstance(result, dict) - assert 'content' in result - assert result['content'] == "" + assert "content" in result + assert result["content"] == "" diff --git a/tests/ra_aid/tools/test_reflection.py b/tests/ra_aid/tools/test_reflection.py index 5f2004f..66de1fa 100644 --- a/tests/ra_aid/tools/test_reflection.py +++ b/tests/ra_aid/tools/test_reflection.py @@ -1,30 +1,34 @@ -import pytest from ra_aid.tools.reflection import get_function_info + # Sample functions for testing get_function_info def simple_func(): """A simple function with no parameters.""" pass + def typed_func(a: int, b: str = "default") -> bool: """A function with type hints and default value. - + Args: a: An integer parameter b: A string parameter with default - + Returns: bool: Always returns True """ return True + def complex_func(pos1, pos2, *args, kw1="default", **kwargs): """A function with complex signature.""" pass + def no_docstring_func(x): pass + class TestGetFunctionInfo: def test_simple_function_info(self): """Test info extraction for simple function.""" @@ -58,5 +62,3 @@ class TestGetFunctionInfo: info = get_function_info(no_docstring_func) assert "no_docstring_func" in info assert "No docstring provided" in info - - diff --git a/tests/ra_aid/tools/test_shell.py b/tests/ra_aid/tools/test_shell.py index 405bdfb..3636ae6 100644 --- a/tests/ra_aid/tools/test_shell.py +++ b/tests/ra_aid/tools/test_shell.py @@ -1,93 +1,107 @@ +from unittest.mock import patch + import pytest -from unittest.mock import patch, MagicMock -from ra_aid.tools.shell import run_shell_command + from ra_aid.tools.memory import _global_memory +from ra_aid.tools.shell import run_shell_command + @pytest.fixture def mock_console(): - with patch('ra_aid.tools.shell.console') as mock: + with patch("ra_aid.tools.shell.console") as mock: yield mock + @pytest.fixture def mock_prompt(): - with patch('ra_aid.tools.shell.Prompt') as mock: + with patch("ra_aid.tools.shell.Prompt") as mock: yield mock + @pytest.fixture def mock_run_interactive(): - with patch('ra_aid.tools.shell.run_interactive_command') as mock: + with patch("ra_aid.tools.shell.run_interactive_command") as mock: mock.return_value = (b"test output", 0) yield mock + def test_shell_command_cowboy_mode(mock_console, mock_prompt, mock_run_interactive): """Test shell command execution in cowboy mode (no approval)""" - _global_memory['config'] = {'cowboy_mode': True} - + _global_memory["config"] = {"cowboy_mode": True} + result = run_shell_command.invoke({"command": "echo test"}) - - assert result['success'] is True - assert result['return_code'] == 0 - assert "test output" in result['output'] + + assert result["success"] is True + assert result["return_code"] == 0 + assert "test output" in result["output"] mock_prompt.ask.assert_not_called() + def test_shell_command_cowboy_message(mock_console, mock_prompt, mock_run_interactive): """Test that cowboy mode displays a properly formatted cowboy message with correct spacing""" - _global_memory['config'] = {'cowboy_mode': True} - - with patch('ra_aid.tools.shell.get_cowboy_message') as mock_get_message: - mock_get_message.return_value = '🀠 Test cowboy message!' + _global_memory["config"] = {"cowboy_mode": True} + + with patch("ra_aid.tools.shell.get_cowboy_message") as mock_get_message: + mock_get_message.return_value = "🀠 Test cowboy message!" result = run_shell_command.invoke({"command": "echo test"}) - - assert result['success'] is True + + assert result["success"] is True mock_console.print.assert_any_call("") mock_console.print.assert_any_call(" 🀠 Test cowboy message!") mock_console.print.assert_any_call("") mock_get_message.assert_called_once() -def test_shell_command_interactive_approved(mock_console, mock_prompt, mock_run_interactive): + +def test_shell_command_interactive_approved( + mock_console, mock_prompt, mock_run_interactive +): """Test shell command execution with interactive approval""" - _global_memory['config'] = {'cowboy_mode': False} - mock_prompt.ask.return_value = 'y' - + _global_memory["config"] = {"cowboy_mode": False} + mock_prompt.ask.return_value = "y" + result = run_shell_command.invoke({"command": "echo test"}) - - assert result['success'] is True - assert result['return_code'] == 0 - assert "test output" in result['output'] + + assert result["success"] is True + assert result["return_code"] == 0 + assert "test output" in result["output"] mock_prompt.ask.assert_called_once_with( "Execute this command? (y=yes, n=no, c=enable cowboy mode for session)", choices=["y", "n", "c"], default="y", show_choices=True, - show_default=True + show_default=True, ) -def test_shell_command_interactive_rejected(mock_console, mock_prompt, mock_run_interactive): + +def test_shell_command_interactive_rejected( + mock_console, mock_prompt, mock_run_interactive +): """Test shell command rejection in interactive mode""" - _global_memory['config'] = {'cowboy_mode': False} - mock_prompt.ask.return_value = 'n' - + _global_memory["config"] = {"cowboy_mode": False} + mock_prompt.ask.return_value = "n" + result = run_shell_command.invoke({"command": "echo test"}) - - assert result['success'] is False - assert result['return_code'] == 1 - assert "cancelled by user" in result['output'] + + assert result["success"] is False + assert result["return_code"] == 1 + assert "cancelled by user" in result["output"] mock_prompt.ask.assert_called_once_with( "Execute this command? (y=yes, n=no, c=enable cowboy mode for session)", choices=["y", "n", "c"], default="y", show_choices=True, - show_default=True + show_default=True, ) mock_run_interactive.assert_not_called() + def test_shell_command_execution_error(mock_console, mock_prompt, mock_run_interactive): """Test handling of shell command execution errors""" - _global_memory['config'] = {'cowboy_mode': True} + _global_memory["config"] = {"cowboy_mode": True} mock_run_interactive.side_effect = Exception("Command failed") - + result = run_shell_command.invoke({"command": "invalid command"}) - - assert result['success'] is False - assert result['return_code'] == 1 - assert "Command failed" in result['output'] + + assert result["success"] is False + assert result["return_code"] == 1 + assert "Command failed" in result["output"] diff --git a/tests/ra_aid/tools/test_write_file.py b/tests/ra_aid/tools/test_write_file.py index 67d5312..90c844b 100644 --- a/tests/ra_aid/tools/test_write_file.py +++ b/tests/ra_aid/tools/test_write_file.py @@ -1,9 +1,11 @@ import os +from unittest.mock import patch + import pytest -from pathlib import Path -from unittest.mock import patch, mock_open + from ra_aid.tools.write_file import write_file_tool + @pytest.fixture def temp_test_dir(tmp_path): """Create a temporary test directory.""" @@ -11,168 +13,157 @@ def temp_test_dir(tmp_path): test_dir.mkdir(exist_ok=True) return test_dir + def test_basic_write_functionality(temp_test_dir): """Test basic successful file writing.""" test_file = temp_test_dir / "test.txt" content = "Hello, World!\nTest content" - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": content - }) - + + result = write_file_tool.invoke({"filepath": str(test_file), "content": content}) + # Verify file contents assert test_file.read_text() == content - + # Verify return dict format assert isinstance(result, dict) assert result["success"] is True assert result["filepath"] == str(test_file) - assert result["bytes_written"] == len(content.encode('utf-8')) + assert result["bytes_written"] == len(content.encode("utf-8")) assert "Operation completed" in result["message"] + def test_directory_creation(temp_test_dir): """Test writing to a file in a non-existent directory.""" nested_dir = temp_test_dir / "nested" / "subdirs" test_file = nested_dir / "test.txt" content = "Test content" - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": content - }) - + + result = write_file_tool.invoke({"filepath": str(test_file), "content": content}) + assert test_file.exists() assert test_file.read_text() == content assert result["success"] is True + def test_different_encodings(temp_test_dir): """Test writing files with different encodings.""" test_file = temp_test_dir / "encoded.txt" content = "Hello δΈ–η•Œ" # Mixed ASCII and Unicode - - # Test UTF-8 - result_utf8 = write_file_tool.invoke({ - "filepath": str(test_file), - "content": content, - "encoding": 'utf-8' - }) - assert result_utf8["success"] is True - assert test_file.read_text(encoding='utf-8') == content - - # Test UTF-16 - result_utf16 = write_file_tool.invoke({ - "filepath": str(test_file), - "content": content, - "encoding": 'utf-16' - }) - assert result_utf16["success"] is True - assert test_file.read_text(encoding='utf-16') == content -@patch('builtins.open') + # Test UTF-8 + result_utf8 = write_file_tool.invoke( + {"filepath": str(test_file), "content": content, "encoding": "utf-8"} + ) + assert result_utf8["success"] is True + assert test_file.read_text(encoding="utf-8") == content + + # Test UTF-16 + result_utf16 = write_file_tool.invoke( + {"filepath": str(test_file), "content": content, "encoding": "utf-16"} + ) + assert result_utf16["success"] is True + assert test_file.read_text(encoding="utf-16") == content + + +@patch("builtins.open") def test_permission_error(mock_open_func, temp_test_dir): """Test handling of permission errors.""" mock_open_func.side_effect = PermissionError("Permission denied") test_file = temp_test_dir / "noperm.txt" - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": "test content" - }) - + + result = write_file_tool.invoke( + {"filepath": str(test_file), "content": "test content"} + ) + assert result["success"] is False assert "Permission denied" in result["message"] assert result["error"] is not None -@patch('builtins.open') + +@patch("builtins.open") def test_io_error(mock_open_func, temp_test_dir): """Test handling of IO errors.""" mock_open_func.side_effect = IOError("IO Error occurred") test_file = temp_test_dir / "ioerror.txt" - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": "test content" - }) - + + result = write_file_tool.invoke( + {"filepath": str(test_file), "content": "test content"} + ) + assert result["success"] is False assert "IO Error" in result["message"] assert result["error"] is not None + def test_empty_content(temp_test_dir): """Test writing empty content to a file.""" test_file = temp_test_dir / "empty.txt" - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": "" - }) - + + result = write_file_tool.invoke({"filepath": str(test_file), "content": ""}) + assert test_file.exists() assert test_file.read_text() == "" assert result["success"] is True assert result["bytes_written"] == 0 + def test_overwrite_existing_file(temp_test_dir): """Test overwriting an existing file.""" test_file = temp_test_dir / "overwrite.txt" - + # Write initial content test_file.write_text("Initial content") - + # Overwrite with new content new_content = "New content" - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": new_content - }) - + result = write_file_tool.invoke( + {"filepath": str(test_file), "content": new_content} + ) + assert test_file.read_text() == new_content assert result["success"] is True - assert result["bytes_written"] == len(new_content.encode('utf-8')) + assert result["bytes_written"] == len(new_content.encode("utf-8")) + def test_large_file_write(temp_test_dir): """Test writing a large file and verify statistics.""" test_file = temp_test_dir / "large.txt" content = "Large content\n" * 1000 # Create substantial content - - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": content - }) - + + result = write_file_tool.invoke({"filepath": str(test_file), "content": content}) + assert test_file.exists() assert test_file.read_text() == content assert result["success"] is True - assert result["bytes_written"] == len(content.encode('utf-8')) - assert os.path.getsize(test_file) == len(content.encode('utf-8')) + assert result["bytes_written"] == len(content.encode("utf-8")) + assert os.path.getsize(test_file) == len(content.encode("utf-8")) + def test_invalid_path_characters(temp_test_dir): """Test handling of invalid path characters.""" invalid_path = temp_test_dir / "invalid\0file.txt" - - result = write_file_tool.invoke({ - "filepath": str(invalid_path), - "content": "test content" - }) - + + result = write_file_tool.invoke( + {"filepath": str(invalid_path), "content": "test content"} + ) + assert result["success"] is False assert "Invalid file path" in result["message"] + def test_write_to_readonly_directory(temp_test_dir): """Test writing to a readonly directory.""" readonly_dir = temp_test_dir / "readonly" readonly_dir.mkdir() test_file = readonly_dir / "test.txt" - + # Make directory readonly os.chmod(readonly_dir, 0o444) - + try: - result = write_file_tool.invoke({ - "filepath": str(test_file), - "content": "test content" - }) + result = write_file_tool.invoke( + {"filepath": str(test_file), "content": "test content"} + ) assert result["success"] is False assert "Permission" in result["message"] finally: diff --git a/tests/scripts/test_extract_changelog.py b/tests/scripts/test_extract_changelog.py index f9da477..ca5433f 100644 --- a/tests/scripts/test_extract_changelog.py +++ b/tests/scripts/test_extract_changelog.py @@ -1,6 +1,8 @@ import pytest + from scripts.extract_changelog import extract_version_content + @pytest.fixture def basic_changelog(): return """## [1.2.0] @@ -14,6 +16,7 @@ def basic_changelog(): - Change Y """ + @pytest.fixture def complex_changelog(): return """## [2.0.0] @@ -30,6 +33,7 @@ def complex_changelog(): Some content """ + def test_basic_version_extraction(basic_changelog): """Test extracting a simple version entry""" result = extract_version_content(basic_changelog, "1.2.0") @@ -39,6 +43,7 @@ def test_basic_version_extraction(basic_changelog): - Feature B""" assert result == expected + def test_middle_version_extraction(complex_changelog): """Test extracting a version from middle of changelog""" result = extract_version_content(complex_changelog, "1.9.0") @@ -49,22 +54,26 @@ def test_middle_version_extraction(complex_changelog): - Bug fix""" assert result == expected + def test_version_not_found(): """Test error handling when version doesn't exist""" with pytest.raises(ValueError, match="Version 9.9.9 not found in changelog"): extract_version_content("## [1.0.0]\nSome content", "9.9.9") + def test_empty_changelog(): """Test handling empty changelog""" with pytest.raises(ValueError, match="Version 1.0.0 not found in changelog"): extract_version_content("", "1.0.0") + def test_malformed_changelog(): """Test handling malformed changelog without proper version headers""" content = "Some content\nNo version headers here\n" with pytest.raises(ValueError, match="Version 1.0.0 not found in changelog"): extract_version_content(content, "1.0.0") + def test_version_with_special_chars(): """Test handling versions with special regex characters""" content = """## [1.0.0-beta.1] diff --git a/tests/test_file_listing.py b/tests/test_file_listing.py index 623c184..91f3efa 100644 --- a/tests/test_file_listing.py +++ b/tests/test_file_listing.py @@ -1,19 +1,21 @@ """Tests for file listing functionality.""" import os -import pytest -from pathlib import Path import subprocess -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + +import pytest + from ra_aid.file_listing import ( + DirectoryAccessError, + DirectoryNotFoundError, + FileListerError, + GitCommandError, get_file_listing, is_git_repo, - GitCommandError, - DirectoryNotFoundError, - DirectoryAccessError, - FileListerError, ) + @pytest.fixture def empty_git_repo(tmp_path): """Create an empty git repository.""" @@ -30,7 +32,7 @@ def sample_git_repo(empty_git_repo): "src/main.py", "src/utils.py", "tests/test_main.py", - "docs/index.html" + "docs/index.html", ] for file_path in files: @@ -43,10 +45,12 @@ def sample_git_repo(empty_git_repo): subprocess.run( ["git", "commit", "-m", "Initial commit"], cwd=empty_git_repo, - env={"GIT_AUTHOR_NAME": "Test", - "GIT_AUTHOR_EMAIL": "test@example.com", - "GIT_COMMITTER_NAME": "Test", - "GIT_COMMITTER_EMAIL": "test@example.com"} + env={ + "GIT_AUTHOR_NAME": "Test", + "GIT_AUTHOR_EMAIL": "test@example.com", + "GIT_COMMITTER_NAME": "Test", + "GIT_COMMITTER_EMAIL": "test@example.com", + }, ) return empty_git_repo @@ -154,7 +158,10 @@ FILE_LISTING_TEST_CASES = [ }, { "name": "duplicate_files", - "git_output": "\n".join([SINGLE_FILE_NAME, SINGLE_FILE_NAME] + MULTI_FILE_NAMES[1:]) + "\n", + "git_output": "\n".join( + [SINGLE_FILE_NAME, SINGLE_FILE_NAME] + MULTI_FILE_NAMES[1:] + ) + + "\n", "expected_files": [SINGLE_FILE_NAME] + MULTI_FILE_NAMES[1:], "expected_total": 3, # After deduplication "limit": None, @@ -217,6 +224,7 @@ FILE_LISTING_TEST_CASES = [ }, ] + def create_mock_process(git_output: str) -> MagicMock: """Create a mock process with the given git output.""" mock_process = MagicMock() @@ -224,12 +232,14 @@ def create_mock_process(git_output: str) -> MagicMock: mock_process.returncode = 0 return mock_process + @pytest.fixture def mock_subprocess(): """Fixture to mock subprocess.run.""" with patch("subprocess.run") as mock_run: yield mock_run + @pytest.fixture def mock_is_git_repo(): """Fixture to mock is_git_repo function.""" @@ -237,6 +247,7 @@ def mock_is_git_repo(): mock.return_value = True yield mock + @pytest.mark.parametrize("test_case", FILE_LISTING_TEST_CASES, ids=lambda x: x["name"]) def test_get_file_listing(test_case, mock_subprocess, mock_is_git_repo): """Test get_file_listing with various inputs.""" @@ -245,6 +256,7 @@ def test_get_file_listing(test_case, mock_subprocess, mock_is_git_repo): assert files == test_case["expected_files"] assert total == test_case["expected_total"] + def test_get_file_listing_non_git_repo(mock_is_git_repo): """Test get_file_listing with non-git repository.""" mock_is_git_repo.return_value = False @@ -252,21 +264,23 @@ def test_get_file_listing_non_git_repo(mock_is_git_repo): assert files == EMPTY_FILE_LIST assert total == EMPTY_FILE_TOTAL + def test_get_file_listing_git_error(mock_subprocess, mock_is_git_repo): """Test get_file_listing when git command fails.""" mock_subprocess.side_effect = GitCommandError("Git command failed") with pytest.raises(GitCommandError): get_file_listing(DUMMY_PATH) + def test_get_file_listing_permission_error(mock_subprocess, mock_is_git_repo): """Test get_file_listing with permission error.""" mock_subprocess.side_effect = PermissionError("Permission denied") with pytest.raises(DirectoryAccessError): get_file_listing(DUMMY_PATH) + def test_get_file_listing_unexpected_error(mock_subprocess, mock_is_git_repo): """Test get_file_listing with unexpected error.""" mock_subprocess.side_effect = Exception("Unexpected error") with pytest.raises(FileListerError): get_file_listing(DUMMY_PATH) - diff --git a/tests/test_project_info.py b/tests/test_project_info.py index 3b0a7e1..7c55ff1 100644 --- a/tests/test_project_info.py +++ b/tests/test_project_info.py @@ -2,22 +2,18 @@ import os import subprocess -import pytest -from pathlib import Path -from ra_aid.project_info import ( - get_project_info, - ProjectInfo, - ProjectInfoError -) -from ra_aid.project_state import DirectoryNotFoundError, DirectoryAccessError -from ra_aid.file_listing import GitCommandError +import pytest + +from ra_aid.project_info import ProjectInfo, get_project_info +from ra_aid.project_state import DirectoryAccessError, DirectoryNotFoundError @pytest.fixture def empty_git_repo(tmp_path): """Create an empty git repository.""" import subprocess + subprocess.run(["git", "init"], cwd=tmp_path, capture_output=True) return tmp_path @@ -31,25 +27,27 @@ def sample_git_repo(empty_git_repo): "src/main.py", "src/utils.py", "tests/test_main.py", - "docs/index.html" + "docs/index.html", ] - + for file_path in files: full_path = empty_git_repo / file_path full_path.parent.mkdir(parents=True, exist_ok=True) full_path.write_text(f"Content of {file_path}") - + # Add and commit files subprocess.run(["git", "add", "."], cwd=empty_git_repo) subprocess.run( ["git", "commit", "-m", "Initial commit"], cwd=empty_git_repo, - env={"GIT_AUTHOR_NAME": "Test", - "GIT_AUTHOR_EMAIL": "test@example.com", - "GIT_COMMITTER_NAME": "Test", - "GIT_COMMITTER_EMAIL": "test@example.com"} + env={ + "GIT_AUTHOR_NAME": "Test", + "GIT_AUTHOR_EMAIL": "test@example.com", + "GIT_COMMITTER_NAME": "Test", + "GIT_COMMITTER_EMAIL": "test@example.com", + }, ) - + return empty_git_repo @@ -89,7 +87,7 @@ def test_file_as_directory(tmp_path): """Test handling of file path instead of directory.""" test_file = tmp_path / "test.txt" test_file.write_text("test") - + with pytest.raises(DirectoryNotFoundError): get_project_info(str(test_file)) @@ -100,7 +98,7 @@ def test_permission_error(tmp_path): try: # Make directory unreadable os.chmod(tmp_path, 0o000) - + with pytest.raises(DirectoryAccessError): get_project_info(str(tmp_path)) finally: diff --git a/tests/test_project_state.py b/tests/test_project_state.py index c194d61..1cb41af 100644 --- a/tests/test_project_state.py +++ b/tests/test_project_state.py @@ -1,14 +1,14 @@ """Tests for project state detection functionality.""" import os + import pytest -from pathlib import Path from ra_aid.project_state import ( - is_new_project, - DirectoryNotFoundError, DirectoryAccessError, - ProjectStateError + DirectoryNotFoundError, + ProjectStateError, + is_new_project, ) @@ -81,7 +81,7 @@ def test_file_as_directory(tmp_path): """Test that passing a file instead of directory raises error.""" test_file = tmp_path / "test.txt" test_file.write_text("test") - + with pytest.raises(ProjectStateError): is_new_project(str(test_file)) @@ -92,7 +92,7 @@ def test_permission_error(tmp_path): try: # Make directory unreadable os.chmod(tmp_path, 0o000) - + with pytest.raises(DirectoryAccessError): is_new_project(str(tmp_path)) finally: