diff --git a/.env.sample b/.env.sample index 51bbc471..3b47252a 100644 --- a/.env.sample +++ b/.env.sample @@ -2,4 +2,9 @@ GEMINI_PROJECT_ID= GEMINI_API_KEY= GITHUB_TOKEN= OPENROUTER_API_KEY = -OPENROUTER_MODEL = \ No newline at end of file +OPENROUTER_MODEL = +# MiniMax (https://www.minimax.io) — set API key to auto-detect +MINIMAX_API_KEY= +# Optional overrides (defaults: MiniMax-M2.7, https://api.minimax.io) +# MINIMAX_MODEL=MiniMax-M2.7 +# MINIMAX_BASE_URL=https://api.minimax.io \ No newline at end of file diff --git a/README.md b/README.md index cc8ad4e8..65cf0a41 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ This is a tutorial project of [Pocket Flow](https://github.com/The-Pocket/Pocket pip install -r requirements.txt ``` -4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. To do so, you can put the values in a `.env` file. By default, you can use the AI Studio key with this client for Gemini Pro 2.5 by setting the `GEMINI_API_KEY` environment variable. If you want to use another LLM, you can set the `LLM_PROVIDER` environment variable (e.g. `XAI`), and then set the model, url, and API key (e.g. `XAI_MODEL`, `XAI_URL`,`XAI_API_KEY`). If using Ollama, the url is `http://localhost:11434/` and the API key can be omitted. +4. Set up LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials. To do so, you can put the values in a `.env` file. By default, you can use the AI Studio key with this client for Gemini Pro 2.5 by setting the `GEMINI_API_KEY` environment variable. If you want to use another LLM, you can set the `LLM_PROVIDER` environment variable (e.g. `XAI`), and then set the model, url, and API key (e.g. `XAI_MODEL`, `XAI_URL`,`XAI_API_KEY`). If using Ollama, the url is `http://localhost:11434/` and the API key can be omitted. For [MiniMax](https://www.minimax.io), just set `MINIMAX_API_KEY` — the provider, model (`MiniMax-M2.7`), and base URL are auto-configured. 
You can use your own models. We highly recommend the latest models with thinking capabilities (Claude 3.7 with thinking, O1). You can verify that it is correctly set up by running: ```bash python utils/call_llm.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_call_llm.py b/tests/test_call_llm.py new file mode 100644 index 00000000..c7b69c7c --- /dev/null +++ b/tests/test_call_llm.py @@ -0,0 +1,270 @@ +"""Unit tests for utils/call_llm.py – focused on MiniMax provider support.""" + +import json +import os +from unittest import mock + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _clean_env(monkeypatch): + """Remove all LLM-related env vars to start with a blank slate.""" + for key in list(os.environ): + if key.startswith(("LLM_PROVIDER", "GEMINI_", "MINIMAX_", "OPENAI_", "XAI_", "OLLAMA_")): + monkeypatch.delenv(key, raising=False) + + +# --------------------------------------------------------------------------- +# get_llm_provider – auto-detection +# --------------------------------------------------------------------------- + +class TestGetLlmProvider: + """Tests for the provider auto-detection logic.""" + + def test_explicit_provider_takes_priority(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("LLM_PROVIDER", "XAI") + monkeypatch.setenv("MINIMAX_API_KEY", "test-key") + from utils.call_llm import get_llm_provider + assert get_llm_provider() == "XAI" + + def test_gemini_auto_detected_before_minimax(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("GEMINI_API_KEY", "gemini-key") + monkeypatch.setenv("MINIMAX_API_KEY", "minimax-key") + from utils.call_llm import get_llm_provider + assert get_llm_provider() == "GEMINI" + + def test_minimax_auto_detected_by_api_key(self, monkeypatch): + _clean_env(monkeypatch) 
+ monkeypatch.setenv("MINIMAX_API_KEY", "test-key") + from utils.call_llm import get_llm_provider + assert get_llm_provider() == "MINIMAX" + + def test_no_provider_detected(self, monkeypatch): + _clean_env(monkeypatch) + from utils.call_llm import get_llm_provider + assert get_llm_provider() is None + + +# --------------------------------------------------------------------------- +# _PROVIDER_DEFAULTS – MiniMax defaults +# --------------------------------------------------------------------------- + +class TestProviderDefaults: + """Verify built-in defaults for the MINIMAX provider.""" + + def test_minimax_defaults_exist(self): + from utils.call_llm import _PROVIDER_DEFAULTS + assert "MINIMAX" in _PROVIDER_DEFAULTS + + def test_minimax_base_url(self): + from utils.call_llm import _PROVIDER_DEFAULTS + assert _PROVIDER_DEFAULTS["MINIMAX"]["base_url"] == "https://api.minimax.io" + + def test_minimax_default_model(self): + from utils.call_llm import _PROVIDER_DEFAULTS + assert _PROVIDER_DEFAULTS["MINIMAX"]["model"] == "MiniMax-M2.7" + + def test_minimax_temperature_range(self): + from utils.call_llm import _PROVIDER_DEFAULTS + d = _PROVIDER_DEFAULTS["MINIMAX"] + assert d["min_temperature"] > 0 + assert d["max_temperature"] <= 1.0 + + +# --------------------------------------------------------------------------- +# _call_llm_provider – MiniMax request construction +# --------------------------------------------------------------------------- + +class TestCallLlmProviderMiniMax: + """Test that _call_llm_provider builds the correct request for MiniMax.""" + + def test_uses_default_model_and_base_url(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("LLM_PROVIDER", "MINIMAX") + monkeypatch.setenv("MINIMAX_API_KEY", "test-key-123") + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "Hello from MiniMax!"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with 
mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post: + from utils.call_llm import _call_llm_provider + result = _call_llm_provider("test prompt") + + assert result == "Hello from MiniMax!" + call_args = mock_post.call_args + url = call_args[0][0] + payload = call_args[1]["json"] + headers = call_args[1]["headers"] + assert "api.minimax.io" in url + assert payload["model"] == "MiniMax-M2.7" + assert headers["Authorization"] == "Bearer test-key-123" + + def test_custom_model_overrides_default(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("LLM_PROVIDER", "MINIMAX") + monkeypatch.setenv("MINIMAX_API_KEY", "key") + monkeypatch.setenv("MINIMAX_MODEL", "MiniMax-M2.7-highspeed") + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "fast"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post: + from utils.call_llm import _call_llm_provider + _call_llm_provider("prompt") + + payload = mock_post.call_args[1]["json"] + assert payload["model"] == "MiniMax-M2.7-highspeed" + + def test_custom_base_url_overrides_default(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("LLM_PROVIDER", "MINIMAX") + monkeypatch.setenv("MINIMAX_API_KEY", "key") + monkeypatch.setenv("MINIMAX_BASE_URL", "https://custom.endpoint.io") + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "ok"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post: + from utils.call_llm import _call_llm_provider + _call_llm_provider("prompt") + + url = mock_post.call_args[0][0] + assert url.startswith("https://custom.endpoint.io") + + def test_temperature_clamped_to_provider_range(self, monkeypatch): + _clean_env(monkeypatch) + 
monkeypatch.setenv("LLM_PROVIDER", "MINIMAX") + monkeypatch.setenv("MINIMAX_API_KEY", "key") + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "ok"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post: + from utils.call_llm import _call_llm_provider + _call_llm_provider("prompt") + + payload = mock_post.call_args[1]["json"] + temp = payload["temperature"] + assert 0 < temp <= 1.0, f"Temperature {temp} out of MiniMax range" + + def test_no_defaults_for_unknown_provider(self, monkeypatch): + _clean_env(monkeypatch) + monkeypatch.setenv("LLM_PROVIDER", "CUSTOM") + # Without model and base_url, should raise ValueError + with pytest.raises(ValueError, match="CUSTOM_MODEL"): + from utils.call_llm import _call_llm_provider + _call_llm_provider("prompt") + + +# --------------------------------------------------------------------------- +# call_llm – end-to-end with MiniMax auto-detection +# --------------------------------------------------------------------------- + +class TestCallLlmMiniMaxIntegration: + """Test call_llm() with MiniMax auto-detection (env var only, mocked HTTP).""" + + def test_auto_detect_and_call(self, monkeypatch, tmp_path): + _clean_env(monkeypatch) + monkeypatch.setenv("MINIMAX_API_KEY", "test-key") + + # Use a temp cache file so we don't pollute the repo + monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json")) + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "MiniMax response"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with mock.patch("utils.call_llm.requests.post", return_value=mock_response): + from utils.call_llm import call_llm + result = call_llm("hello", use_cache=False) + + assert result == "MiniMax response" + + def test_cache_works_with_minimax(self, monkeypatch, tmp_path): 
+ _clean_env(monkeypatch) + monkeypatch.setenv("MINIMAX_API_KEY", "test-key") + + cache_path = str(tmp_path / "cache.json") + monkeypatch.setattr("utils.call_llm.cache_file", cache_path) + + mock_response = mock.MagicMock() + mock_response.json.return_value = { + "choices": [{"message": {"content": "first call"}}] + } + mock_response.raise_for_status = mock.MagicMock() + + with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post: + from utils.call_llm import call_llm + r1 = call_llm("cached prompt", use_cache=True) + r2 = call_llm("cached prompt", use_cache=True) + + # Second call should use cache, not make another HTTP request + assert r1 == r2 == "first call" + assert mock_post.call_count == 1 + + +# --------------------------------------------------------------------------- +# Integration test – real API call (requires MINIMAX_API_KEY) +# --------------------------------------------------------------------------- + +_LIVE_API_KEY = os.environ.get("MINIMAX_API_KEY", "") + + +@pytest.mark.skipif( + not _LIVE_API_KEY, + reason="MINIMAX_API_KEY not set – skipping live integration test", +) +class TestMiniMaxLiveIntegration: + """Live integration tests that call the real MiniMax API.""" + + def test_simple_chat_completion(self, monkeypatch, tmp_path): + _clean_env(monkeypatch) + monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY) + monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json")) + + from utils.call_llm import call_llm + result = call_llm("Reply with exactly: HELLO", use_cache=False) + assert "HELLO" in result.upper() + + def test_longer_response(self, monkeypatch, tmp_path): + _clean_env(monkeypatch) + monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY) + monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json")) + + from utils.call_llm import call_llm + result = call_llm( + "List three colors, one per line. 
No other text.", + use_cache=False, + ) + lines = [l.strip() for l in result.strip().splitlines() if l.strip()] + assert len(lines) >= 3 + + def test_highspeed_model(self, monkeypatch, tmp_path): + _clean_env(monkeypatch) + monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY) + monkeypatch.setenv("MINIMAX_MODEL", "MiniMax-M2.7-highspeed") + monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json")) + + from utils.call_llm import call_llm + result = call_llm("Say OK", use_cache=False) + assert len(result) > 0 diff --git a/utils/call_llm.py b/utils/call_llm.py index 70c9e83a..bada3c29 100644 --- a/utils/call_llm.py +++ b/utils/call_llm.py @@ -47,10 +47,22 @@ def get_llm_provider(): provider = os.getenv("LLM_PROVIDER") if not provider and (os.getenv("GEMINI_PROJECT_ID") or os.getenv("GEMINI_API_KEY")): provider = "GEMINI" - # if necessary, add ANTHROPIC/OPENAI + if not provider and os.getenv("MINIMAX_API_KEY"): + provider = "MINIMAX" return provider +# Provider-specific defaults for base URL, model, and temperature range +_PROVIDER_DEFAULTS = { + "MINIMAX": { + "base_url": "https://api.minimax.io", + "model": "MiniMax-M2.7", + "min_temperature": 0.01, + "max_temperature": 1.0, + }, +} + + def _call_llm_provider(prompt: str) -> str: """ Call an LLM provider based on environment variables. 
@@ -73,9 +85,12 @@ def _call_llm_provider(prompt: str) -> str: base_url_var = f"{provider}_BASE_URL" api_key_var = f"{provider}_API_KEY" - # Read the provider-specific variables - model = os.environ.get(model_var) - base_url = os.environ.get(base_url_var) + # Look up provider defaults (if any) + defaults = _PROVIDER_DEFAULTS.get(provider, {}) + + # Read the provider-specific variables, falling back to defaults + model = os.environ.get(model_var) or defaults.get("model") + base_url = os.environ.get(base_url_var) or defaults.get("base_url") api_key = os.environ.get(api_key_var, "") # API key is optional, default to empty string # Validate required variables @@ -94,10 +109,19 @@ def _call_llm_provider(prompt: str) -> str: if api_key: # Only add Authorization header if API key is provided headers["Authorization"] = f"Bearer {api_key}" + # Clamp temperature to the provider's accepted range + temperature = 0.7 + min_temp = defaults.get("min_temperature") + max_temp = defaults.get("max_temperature") + if min_temp is not None: + temperature = max(temperature, min_temp) + if max_temp is not None: + temperature = min(temperature, max_temp) + payload = { "model": model, "messages": [{"role": "user", "content": prompt}], - "temperature": 0.7, + "temperature": temperature, } try: @@ -142,6 +166,9 @@ def call_llm(prompt: str, use_cache: bool = True) -> str: if provider == "GEMINI": response_text = _call_llm_gemini(prompt) else: # generic method using a URL that is OpenAI compatible API (Ollama, ...) + # Ensure LLM_PROVIDER is available for auto-detected providers + if provider and not os.environ.get("LLM_PROVIDER"): + os.environ["LLM_PROVIDER"] = provider response_text = _call_llm_provider(prompt) # Log the response