Add --temperature CLI parameter.

AI Christianson 2024-12-28 18:36:24 -05:00
parent ace34633de
commit a28ed59bca
4 changed files with 96 additions and 3 deletions


@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add CiaynAgent to support models that do not have, or are not good at, agentic function calling.
 - Improve env var validation.
+- Add --temperature CLI parameter.
 
 ## [0.10.3] - 2024-12-27


@@ -104,6 +104,12 @@ Examples:
         action='store_true',
         help='Enable verbose logging output'
     )
+    parser.add_argument(
+        '--temperature',
+        type=float,
+        help='LLM temperature (0.0-2.0). Controls randomness in responses',
+        default=None
+    )
 
     if args is None:
         args = sys.argv[1:]
@@ -129,6 +135,10 @@ Examples:
     if parsed_args.expert_provider != 'openai' and not parsed_args.expert_model and not parsed_args.research_only:
         parser.error(f"--expert-model is required when using expert provider '{parsed_args.expert_provider}'")
 
+    # Validate temperature range if provided
+    if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0):
+        parser.error('Temperature must be between 0.0 and 2.0')
+
     return parsed_args
 
 # Create console instance
@@ -179,7 +189,7 @@ def main():
     ))
 
     # Create the base model after validation
-    model = initialize_llm(args.provider, args.model)
+    model = initialize_llm(args.provider, args.model, temperature=args.temperature)
 
     # Handle chat mode
     if args.chat:

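Taken together, the two hunks above define the new flag and reject out-of-range values before any model is created. As a quick reference, here is a minimal, self-contained sketch of that flow; it is an editor's illustration, not the project's actual parse_arguments (which defines many more options), and the --message flag is assumed from the tests further down in this commit:

# Editor's sketch: standalone illustration of the --temperature parsing and
# validation pattern shown in the diff above; not the full ra_aid CLI.
import argparse
import sys

def parse_arguments(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--message', help='Task message (assumed existing flag)')
    parser.add_argument(
        '--temperature',
        type=float,
        help='LLM temperature (0.0-2.0). Controls randomness in responses',
        default=None
    )
    if args is None:
        args = sys.argv[1:]
    parsed_args = parser.parse_args(args)
    # Reject out-of-range values here so downstream code can trust the setting.
    if parsed_args.temperature is not None and not (0.0 <= parsed_args.temperature <= 2.0):
        parser.error('Temperature must be between 0.0 and 2.0')
    return parsed_args

if __name__ == '__main__':
    # parser.error() raises SystemExit(2), so invalid values abort parsing.
    print(parse_arguments(['--message', 'hi', '--temperature', '0.7']).temperature)  # 0.7

Because parser.error() raises SystemExit, out-of-range values abort argument parsing, which is exactly what the validation tests at the end of this commit assert.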

@@ -3,7 +3,7 @@ from langchain_openai import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
 from langchain_core.language_models import BaseChatModel
 
-def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
+def initialize_llm(provider: str, model_name: str, temperature: float | None = None) -> BaseChatModel:
     """Initialize a language model client based on the specified provider and model.
 
     Note: Environment variables must be validated before calling this function.
@@ -12,6 +12,8 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
     Args:
         provider: The LLM provider to use ('openai', 'anthropic', 'openrouter', 'openai-compatible')
         model_name: Name of the model to use
+        temperature: Optional temperature setting for controlling randomness (0.0-2.0).
+            If not specified, provider-specific defaults are used.
 
     Returns:
         BaseChatModel: Configured language model client
@@ -23,23 +25,26 @@ def initialize_llm(provider: str, model_name: str) -> BaseChatModel:
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "anthropic":
         return ChatAnthropic(
             api_key=os.getenv("ANTHROPIC_API_KEY"),
             model_name=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openrouter":
         return ChatOpenAI(
             api_key=os.getenv("OPENROUTER_API_KEY"),
             base_url="https://openrouter.ai/api/v1",
             model=model_name,
+            **({"temperature": temperature} if temperature is not None else {})
         )
     elif provider == "openai-compatible":
         return ChatOpenAI(
             api_key=os.getenv("OPENAI_API_KEY"),
             base_url=os.getenv("OPENAI_API_BASE"),
-            temperature=0.3,
+            temperature=temperature if temperature is not None else 0.3,
             model=model_name,
         )
     else:

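The repeated **({"temperature": temperature} if temperature is not None else {}) expression passes a temperature keyword only when the caller supplied one; unpacking an empty dict is a no-op, so otherwise each provider library's own default applies, and only the openai-compatible branch keeps an explicit 0.3 fallback. A small runnable sketch of the same idiom (FakeClient is a hypothetical stand-in for ChatOpenAI/ChatAnthropic, used only so the example runs without API keys):

# Editor's sketch of the conditional-kwargs idiom used in initialize_llm above.
class FakeClient:
    """Hypothetical stand-in for a provider client class (assumption, not a real API)."""
    def __init__(self, model: str, temperature: float = 0.7):
        self.model = model
        self.temperature = temperature  # class default used when not overridden

def make_client(model_name: str, temperature: float | None = None) -> FakeClient:
    # The dict unpacks to a temperature keyword only when a value was given;
    # otherwise it unpacks to nothing and the client's default wins.
    return FakeClient(
        model=model_name,
        **({"temperature": temperature} if temperature is not None else {})
    )

print(make_client("test-model").temperature)       # 0.7 -> provider default wins
print(make_client("test-model", 0.2).temperature)  # 0.2 -> explicit value passed through

This is why the tests below can assert that openai and anthropic clients are constructed without any temperature argument unless --temperature was given.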

@@ -154,6 +154,83 @@ def test_initialize_unsupported_provider(clean_env):
         initialize_llm("unsupported", "model")
     assert str(exc_info.value) == "Unsupported provider: unsupported"
 
+def test_temperature_defaults(clean_env, mock_openai, mock_anthropic):
+    """Test default temperature behavior for different providers."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENAI_API_BASE"] = "http://test-url"
+
+    # Test openai-compatible default temperature
+    initialize_llm("openai-compatible", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="http://test-url",
+        model="test-model",
+        temperature=0.3
+    )
+
+    # Test other providers don't set temperature by default
+    initialize_llm("openai", "test-model")
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model"
+    )
+
+    initialize_llm("anthropic", "test-model")
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model"
+    )
+
+def test_explicit_temperature(clean_env, mock_openai, mock_anthropic):
+    """Test explicit temperature setting for each provider."""
+    os.environ["OPENAI_API_KEY"] = "test-key"
+    os.environ["ANTHROPIC_API_KEY"] = "test-key"
+    os.environ["OPENROUTER_API_KEY"] = "test-key"
+    test_temp = 0.7
+
+    # Test OpenAI
+    initialize_llm("openai", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        model="test-model",
+        temperature=test_temp
+    )
+
+    # Test Anthropic
+    initialize_llm("anthropic", "test-model", temperature=test_temp)
+    mock_anthropic.assert_called_with(
+        api_key="test-key",
+        model_name="test-model",
+        temperature=test_temp
+    )
+
+    # Test OpenRouter
+    initialize_llm("openrouter", "test-model", temperature=test_temp)
+    mock_openai.assert_called_with(
+        api_key="test-key",
+        base_url="https://openrouter.ai/api/v1",
+        model="test-model",
+        temperature=test_temp
+    )
+
+def test_temperature_validation(clean_env, mock_openai):
+    """Test temperature validation in command line arguments."""
+    from ra_aid.__main__ import parse_arguments
+
+    # Test temperature below minimum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '-0.1'])
+
+    # Test temperature above maximum
+    with pytest.raises(SystemExit):
+        parse_arguments(['--message', 'test', '--temperature', '2.1'])
+
+    # Test valid temperature
+    args = parse_arguments(['--message', 'test', '--temperature', '0.7'])
+    assert args.temperature == 0.7
+
 def test_provider_name_validation():
     """Test provider name validation and normalization."""
     # Test all supported providers
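One detail worth knowing when reading the mock assertions in the new tests: unittest.mock's assert_called_with checks only the most recent call, which is why each initialize_llm call is immediately followed by its assertion before the next provider is exercised. A tiny illustration (the mock here is an editor's example, not the project's fixture):

from unittest.mock import MagicMock

mock_openai = MagicMock()
mock_openai(model="first")
mock_openai(model="second", temperature=0.7)

# Passes: assert_called_with inspects only the most recent call.
mock_openai.assert_called_with(model="second", temperature=0.7)

# Would raise AssertionError, even though the earlier call did happen:
# mock_openai.assert_called_with(model="first")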