diff --git a/ra_aid/tools/shell.py b/ra_aid/tools/shell.py index a1a5b91..06fa7d5 100644 --- a/ra_aid/tools/shell.py +++ b/ra_aid/tools/shell.py @@ -1,3 +1,5 @@ +import platform +import shutil from typing import Dict, Union from langchain_core.tools import tool @@ -15,6 +17,16 @@ from ra_aid.database.repositories.human_input_repository import get_human_input_ console = Console() +def _detect_shell(): + """Detect the appropriate shell to use based on the environment.""" + if platform.system().lower().startswith("win"): + # Check if pwsh is available, otherwise fall back to powershell + if shutil.which("pwsh"): + return ["pwsh", "-c"] + else: + return ["powershell", "-c"] + else: + return ["/bin/bash", "-c"] def _truncate_for_log(text: str, max_length: int = 300) -> str: """Truncate text for logging, adding [truncated] if necessary.""" @@ -98,8 +110,9 @@ def run_shell_command( try: print() + shell_cmd = _detect_shell() output, return_code = run_interactive_command( - ["/bin/bash", "-c", command], + shell_cmd + [command], expected_runtime_seconds=timeout, ) print() @@ -131,4 +144,4 @@ def run_shell_command( ) console.print(Panel(str(e), title="❌ Error", border_style="red")) - return {"output": str(e), "return_code": 1, "success": False} \ No newline at end of file + return {"output": str(e), "return_code": 1, "success": False}