7 changes: 6 additions & 1 deletion .env.sample
@@ -2,4 +2,9 @@ GEMINI_PROJECT_ID=<GEMINI_PROJECT_ID>
GEMINI_API_KEY=<GEMINI_API_KEY>
GITHUB_TOKEN=<GITHUB_TOKEN>
OPENROUTER_API_KEY = <OPENROUTER_API_KEY>
OPENROUTER_MODEL = <OPENROUTER_MODEL>
# MiniMax (https://www.minimax.io) — set API key to auto-detect
MINIMAX_API_KEY=<MINIMAX_API_KEY>
# Optional overrides (defaults: MiniMax-M2.7, https://api.minimax.io)
# MINIMAX_MODEL=MiniMax-M2.7
# MINIMAX_BASE_URL=https://api.minimax.io
2 changes: 1 addition & 1 deletion README.md
@@ -86,7 +86,7 @@ This is a tutorial project of [Pocket Flow](https://github.com/The-Pocket/Pocket
pip install -r requirements.txt
```

4. Set up the LLM in [`utils/call_llm.py`](./utils/call_llm.py) by providing credentials; the simplest way is to put the values in a `.env` file. By default, you can use an AI Studio key with this client for Gemini 2.5 Pro by setting the `GEMINI_API_KEY` environment variable. If you want to use another LLM, set the `LLM_PROVIDER` environment variable (e.g. `XAI`), and then set the model, URL, and API key (e.g. `XAI_MODEL`, `XAI_URL`, `XAI_API_KEY`). If using Ollama, the URL is `http://localhost:11434/` and the API key can be omitted. For [MiniMax](https://www.minimax.io), just set `MINIMAX_API_KEY`: the provider, model (`MiniMax-M2.7`), and base URL are auto-configured.
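For example, a minimal `.env` for MiniMax might look like this (the key value is a placeholder; the optional overrides mirror `.env.sample`):
```bash
# MiniMax is auto-detected from the key; no LLM_PROVIDER needed
MINIMAX_API_KEY=your-minimax-key
# Optional overrides:
# MINIMAX_MODEL=MiniMax-M2.7
# MINIMAX_BASE_URL=https://api.minimax.io
```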
You can use your own models; we highly recommend recent models with thinking capabilities (e.g. Claude 3.7 Sonnet with extended thinking, OpenAI o1). You can verify that the setup is correct by running:
```bash
python utils/call_llm.py
```
Empty file added tests/__init__.py
270 changes: 270 additions & 0 deletions tests/test_call_llm.py
@@ -0,0 +1,270 @@
"""Unit tests for utils/call_llm.py – focused on MiniMax provider support."""

import os
from unittest import mock

import pytest


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _clean_env(monkeypatch):
    """Remove all LLM-related env vars to start with a blank slate."""
    for key in list(os.environ):
        if key.startswith(("LLM_PROVIDER", "GEMINI_", "MINIMAX_", "OPENAI_", "XAI_", "OLLAMA_")):
            monkeypatch.delenv(key, raising=False)


# ---------------------------------------------------------------------------
# get_llm_provider – auto-detection
# ---------------------------------------------------------------------------

class TestGetLlmProvider:
    """Tests for the provider auto-detection logic."""

    def test_explicit_provider_takes_priority(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "XAI")
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
        from utils.call_llm import get_llm_provider
        assert get_llm_provider() == "XAI"

    def test_gemini_auto_detected_before_minimax(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("GEMINI_API_KEY", "gemini-key")
        monkeypatch.setenv("MINIMAX_API_KEY", "minimax-key")
        from utils.call_llm import get_llm_provider
        assert get_llm_provider() == "GEMINI"

    def test_minimax_auto_detected_by_api_key(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key")
        from utils.call_llm import get_llm_provider
        assert get_llm_provider() == "MINIMAX"

    def test_no_provider_detected(self, monkeypatch):
        _clean_env(monkeypatch)
        from utils.call_llm import get_llm_provider
        assert get_llm_provider() is None


# ---------------------------------------------------------------------------
# _PROVIDER_DEFAULTS – MiniMax defaults
# ---------------------------------------------------------------------------

class TestProviderDefaults:
    """Verify built-in defaults for the MINIMAX provider."""

    def test_minimax_defaults_exist(self):
        from utils.call_llm import _PROVIDER_DEFAULTS
        assert "MINIMAX" in _PROVIDER_DEFAULTS

    def test_minimax_base_url(self):
        from utils.call_llm import _PROVIDER_DEFAULTS
        assert _PROVIDER_DEFAULTS["MINIMAX"]["base_url"] == "https://api.minimax.io"

    def test_minimax_default_model(self):
        from utils.call_llm import _PROVIDER_DEFAULTS
        assert _PROVIDER_DEFAULTS["MINIMAX"]["model"] == "MiniMax-M2.7"

    def test_minimax_temperature_range(self):
        from utils.call_llm import _PROVIDER_DEFAULTS
        d = _PROVIDER_DEFAULTS["MINIMAX"]
        assert d["min_temperature"] > 0
        assert d["max_temperature"] <= 1.0


# ---------------------------------------------------------------------------
# _call_llm_provider – MiniMax request construction
# ---------------------------------------------------------------------------

class TestCallLlmProviderMiniMax:
    """Test that _call_llm_provider builds the correct request for MiniMax."""

    def test_uses_default_model_and_base_url(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "MINIMAX")
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key-123")

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "Hello from MiniMax!"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post:
            from utils.call_llm import _call_llm_provider
            result = _call_llm_provider("test prompt")

        assert result == "Hello from MiniMax!"
        call_args = mock_post.call_args
        url = call_args[0][0]
        payload = call_args[1]["json"]
        headers = call_args[1]["headers"]
        assert "api.minimax.io" in url
        assert payload["model"] == "MiniMax-M2.7"
        assert headers["Authorization"] == "Bearer test-key-123"

    def test_custom_model_overrides_default(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "MINIMAX")
        monkeypatch.setenv("MINIMAX_API_KEY", "key")
        monkeypatch.setenv("MINIMAX_MODEL", "MiniMax-M2.7-highspeed")

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "fast"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post:
            from utils.call_llm import _call_llm_provider
            _call_llm_provider("prompt")

        payload = mock_post.call_args[1]["json"]
        assert payload["model"] == "MiniMax-M2.7-highspeed"

    def test_custom_base_url_overrides_default(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "MINIMAX")
        monkeypatch.setenv("MINIMAX_API_KEY", "key")
        monkeypatch.setenv("MINIMAX_BASE_URL", "https://custom.endpoint.io")

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "ok"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post:
            from utils.call_llm import _call_llm_provider
            _call_llm_provider("prompt")

        url = mock_post.call_args[0][0]
        assert url.startswith("https://custom.endpoint.io")

    def test_temperature_clamped_to_provider_range(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "MINIMAX")
        monkeypatch.setenv("MINIMAX_API_KEY", "key")

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "ok"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post:
            from utils.call_llm import _call_llm_provider
            _call_llm_provider("prompt")

        payload = mock_post.call_args[1]["json"]
        temp = payload["temperature"]
        assert 0 < temp <= 1.0, f"Temperature {temp} out of MiniMax range"

    def test_no_defaults_for_unknown_provider(self, monkeypatch):
        _clean_env(monkeypatch)
        monkeypatch.setenv("LLM_PROVIDER", "CUSTOM")
        # Without model and base_url, should raise ValueError
        with pytest.raises(ValueError, match="CUSTOM_MODEL"):
            from utils.call_llm import _call_llm_provider
            _call_llm_provider("prompt")


# ---------------------------------------------------------------------------
# call_llm – end-to-end with MiniMax auto-detection
# ---------------------------------------------------------------------------

class TestCallLlmMiniMaxIntegration:
    """Test call_llm() with MiniMax auto-detection (env var only, mocked HTTP)."""

    def test_auto_detect_and_call(self, monkeypatch, tmp_path):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key")

        # Use a temp cache file so we don't pollute the repo
        monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json"))

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "MiniMax response"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response):
            from utils.call_llm import call_llm
            result = call_llm("hello", use_cache=False)

        assert result == "MiniMax response"

    def test_cache_works_with_minimax(self, monkeypatch, tmp_path):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", "test-key")

        cache_path = str(tmp_path / "cache.json")
        monkeypatch.setattr("utils.call_llm.cache_file", cache_path)

        mock_response = mock.MagicMock()
        mock_response.json.return_value = {
            "choices": [{"message": {"content": "first call"}}]
        }
        mock_response.raise_for_status = mock.MagicMock()

        with mock.patch("utils.call_llm.requests.post", return_value=mock_response) as mock_post:
            from utils.call_llm import call_llm
            r1 = call_llm("cached prompt", use_cache=True)
            r2 = call_llm("cached prompt", use_cache=True)

        # Second call should use cache, not make another HTTP request
        assert r1 == r2 == "first call"
        assert mock_post.call_count == 1


# ---------------------------------------------------------------------------
# Integration test – real API call (requires MINIMAX_API_KEY)
# ---------------------------------------------------------------------------

_LIVE_API_KEY = os.environ.get("MINIMAX_API_KEY", "")


@pytest.mark.skipif(
    not _LIVE_API_KEY,
    reason="MINIMAX_API_KEY not set – skipping live integration test",
)
class TestMiniMaxLiveIntegration:
    """Live integration tests that call the real MiniMax API."""

    def test_simple_chat_completion(self, monkeypatch, tmp_path):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY)
        monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json"))

        from utils.call_llm import call_llm
        result = call_llm("Reply with exactly: HELLO", use_cache=False)
        assert "HELLO" in result.upper()

    def test_longer_response(self, monkeypatch, tmp_path):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY)
        monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json"))

        from utils.call_llm import call_llm
        result = call_llm(
            "List three colors, one per line. No other text.",
            use_cache=False,
        )
        lines = [l.strip() for l in result.strip().splitlines() if l.strip()]
        assert len(lines) >= 3

    def test_highspeed_model(self, monkeypatch, tmp_path):
        _clean_env(monkeypatch)
        monkeypatch.setenv("MINIMAX_API_KEY", _LIVE_API_KEY)
        monkeypatch.setenv("MINIMAX_MODEL", "MiniMax-M2.7-highspeed")
        monkeypatch.setattr("utils.call_llm.cache_file", str(tmp_path / "cache.json"))

        from utils.call_llm import call_llm
        result = call_llm("Say OK", use_cache=False)
        assert len(result) > 0
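To run just this suite, a minimal sketch (assuming `pytest` is installed and you run from the repo root so `utils/` is importable):
```bash
pytest tests/test_call_llm.py -v
```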
37 changes: 32 additions & 5 deletions utils/call_llm.py
@@ -47,10 +47,22 @@ def get_llm_provider():
    provider = os.getenv("LLM_PROVIDER")
    if not provider and (os.getenv("GEMINI_PROJECT_ID") or os.getenv("GEMINI_API_KEY")):
        provider = "GEMINI"
    # if necessary, add ANTHROPIC/OPENAI
    if not provider and os.getenv("MINIMAX_API_KEY"):
        provider = "MINIMAX"
    return provider


# Provider-specific defaults for base URL, model, and temperature range
_PROVIDER_DEFAULTS = {
    "MINIMAX": {
        "base_url": "https://api.minimax.io",
        "model": "MiniMax-M2.7",
        "min_temperature": 0.01,
        "max_temperature": 1.0,
    },
}


def _call_llm_provider(prompt: str) -> str:
    """
    Call an LLM provider based on environment variables.
@@ -73,9 +85,12 @@ def _call_llm_provider(prompt: str) -> str:
    base_url_var = f"{provider}_BASE_URL"
    api_key_var = f"{provider}_API_KEY"

    # Look up provider defaults (if any)
    defaults = _PROVIDER_DEFAULTS.get(provider, {})

    # Read the provider-specific variables, falling back to defaults
    model = os.environ.get(model_var) or defaults.get("model")
    base_url = os.environ.get(base_url_var) or defaults.get("base_url")
    api_key = os.environ.get(api_key_var, "")  # API key is optional, default to empty string

    # Validate required variables
@@ -94,10 +109,19 @@ def _call_llm_provider(prompt: str) -> str:
    if api_key:  # Only add Authorization header if API key is provided
        headers["Authorization"] = f"Bearer {api_key}"

    # Clamp temperature to the provider's accepted range
    temperature = 0.7
    min_temp = defaults.get("min_temperature")
    max_temp = defaults.get("max_temperature")
    if min_temp is not None:
        temperature = max(temperature, min_temp)
    if max_temp is not None:
        temperature = min(temperature, max_temp)

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
    }

    try:
@@ -142,6 +166,9 @@ def call_llm(prompt: str, use_cache: bool = True) -> str:
    if provider == "GEMINI":
        response_text = _call_llm_gemini(prompt)
    else:  # generic method using an OpenAI-compatible API URL (Ollama, ...)
        # Ensure LLM_PROVIDER is available for auto-detected providers
        if provider and not os.environ.get("LLM_PROVIDER"):
            os.environ["LLM_PROVIDER"] = provider
        response_text = _call_llm_provider(prompt)

    # Log the response
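With these changes, a quick end-to-end check of the MiniMax path is just one environment variable plus the README's verification script (the key value below is a placeholder):
```bash
export MINIMAX_API_KEY=your-minimax-key  # auto-detected; no LLM_PROVIDER needed
python utils/call_llm.py                 # uses MiniMax-M2.7 at api.minimax.io by default
```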