diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4729d2c..3df4d84 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add CiaynAgent to support models that do not have, or are not good at, agentic function calling.
 - Improve env var validation.
+- Add --temperature CLI parameter.
 
 ## [0.10.3] - 2024-12-27
 
diff --git a/ra_aid/__main__.py b/ra_aid/__main__.py
index 0536902..f68ea5b 100644
--- a/ra_aid/__main__.py
+++ b/ra_aid/__main__.py
@@ -104,6 +104,12 @@ Examples:
         action='store_true',
         help='Enable verbose logging output'
     )
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        help='LLM temperature (0.0-2.0). Controls randomness in responses',
+        default=None
+    )
 
     if args is None:
         args = sys.argv[1:]
@@ -129,6 +135,10 @@ Examples:
     if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
         parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
 
+    # Validate temperature range if provided
+    if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0):
+        parser.error('Temperature must be between 0.0 and 2.0')
+
     return parsed_args
 
 # Create console instance
@@ -179,7 +189,7 @@ def main():
         ))
 
     # Create the base model after validation
-    model = initialize_llm(args.provider, args.model)
+    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
 
     # Handle chat mode
     if args.chat:
diff --git a/ra_aid/llm.py b/ra_aid/llm.py
index bc43fd5..c8b8db9 100644
--- a/ra_aid/llm.py
+++ b/ra_aid/llm.py
@@ -3,7 +3,7 @@ from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 
-def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
+def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel:
     """Initialize a language model client based on the specified provider and model.
 
     Note: Environment variables must be validated before calling this function.
@@ -12,6 +12,8 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
     Args:
         provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
         model_name: Name of the model to use
+        temperature: Optional temperature setting for controlling randomness (0.0-2.0).
+            If not specified, provider-specific defaults are used.
 
     Returns:
         BaseChatModel: Configured language model client
@@ -23,23 +25,26 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
-            temperature=0.3,
+            temperature=temperature if temperature is not None else 0.3,
             model=model_name,
         )
     else:
diff --git a/tests/ra_aid/test_llm.py b/tests/ra_aid/test_llm.py
index cc8d60f..e21be63 100644
--- a/tests/ra_aid/test_llm.py
+++ b/tests/ra_aid/test_llm.py
@@ -154,6 +154,83 @@ def test_initialize_unsupported_provider(clean_env):
         initialize_llm("unsupported", "model")
     assert str(exc_info.value) == "Unsupported provider: unsupported"
 
+def test_temperature_defaults(clean_env, mock_openai, mock_anthropic):
+    """Test default temperature behavior for different providers."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENAI_API_BASE"] = "http://test-url"
+
+    # Test openai-compatible default temperature
+    initialize_llm("openai-compatible", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="http://test-url",
+        model="test-model",
+        temperature=0.3
+    )
+
+    # Test other providers don't set temperature by default
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model"
+    )
+
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model"
+    )
+
+def test_explicit_temperature(clean_env, mock_openai, mock_anthropic):
+    """Test explicit temperature setting for each provider."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENROUTER_API_KEY"] = "test-key"
+
+    test_temp = 0.7
+
+    # Test OpenAI
+    initialize_llm("openai", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=test_temp
+    )
+
+    # Test Anthropic
+    initialize_llm("anthropic", "test-model", temperature=test_temp)
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=test_temp
+    )
+
+    # Test OpenRouter
+    initialize_llm("openrouter", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="https://openrouter.ai/api/v1",
+        model="test-model",
+        temperature=test_temp
+    )
+
+def test_temperature_validation(clean_env, mock_openai):
+    """Test temperature validation in command line arguments."""
+    from ra_aid.__main__ import parse_arguments
+
+    # Test temperature below minimum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '-0.1'])
+
+    # Test temperature above maximum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '2.1'])
+
+    # Test valid temperature
+    args = parse_arguments(['--message', 'test', '--temperature', '0.7'])
+    assert args.temperature == 0.7
+
 def test_provider_name_validation():
     """Test provider name validation and normalization."""
     # Test all supported providers
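A note for reviewers on the conditional-kwargs idiom this patch uses in `ra_aid/llm.py`: the sketch below isolates the pattern outside the codebase. `FakeClient` is a hypothetical stand-in for `ChatOpenAI`/`ChatAnthropic` (it is not part of this repo); the point is that `temperature` is only forwarded when the caller supplied one, so each client's own constructor default applies otherwise.

```python
# Standalone sketch of the conditional-kwargs idiom from ra_aid/llm.py.
# FakeClient is a hypothetical stand-in for ChatOpenAI/ChatAnthropic,
# used only so the example runs without network access or API keys.
from typing import Optional


class FakeClient:
    def __init__(self, model: str, temperature: float = 1.0):
        # 1.0 stands in for whatever default the real client ships with.
        self.model = model
        self.temperature = temperature


def make_client(model: str, temperature: Optional[float] = None) -> FakeClient:
    # Same idiom as the diff: forward temperature only when explicitly set,
    # so the client's own default applies otherwise.
    return FakeClient(
        model=model,
        **({"temperature": temperature} if temperature is not None else {}),
    )


assert make_client("m").temperature == 1.0       # default preserved
assert make_client("m", 0.7).temperature == 0.7  # explicit value wins
assert make_client("m", 0.0).temperature == 0.0  # 0.0 counts as "set"
```

The `is not None` check is the load-bearing choice: with a plain truthiness test (`if temperature`), `--temperature 0.0` would silently fall back to the provider default. Assuming the package's `ra-aid` entry point, usage would look like `ra-aid --message "..." --temperature 0.7`.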