
Commit 5abc293

Author: SentienceDEV
Commit message: fix tests
Parent: 8e01d87

3 files changed, 231 insertions(+), 19 deletions(-)

predicate/llm_provider.py

Lines changed: 215 additions & 10 deletions
@@ -376,12 +376,12 @@ def supports_vision(self) -> bool:
         return super().supports_vision()
 
 
-class OllamaProvider(OpenAIProvider):
+class OllamaProvider(LLMProvider):
     """
     Ollama local LLM provider via OpenAI-compatible API.
 
     Ollama serves models locally and provides an OpenAI-compatible endpoint at /v1.
-    This provider wraps OpenAIProvider with sensible defaults for local inference.
+    This provider uses HTTP requests directly without requiring the openai package.
 
     Example:
         >>> from predicate.llm_provider import OllamaProvider
@@ -397,6 +397,7 @@ def __init__(
         self,
         model: str,
         base_url: str = "http://localhost:11434",
+        timeout_seconds: float = 120.0,
         **kwargs,
     ):
         """
@@ -405,16 +406,117 @@
         Args:
             model: Ollama model name (e.g., "qwen3:8b", "llama3:8b", "mistral:7b")
             base_url: Ollama server URL (default: http://localhost:11434)
-            **kwargs: Additional parameters passed to OpenAIProvider
+            timeout_seconds: Request timeout in seconds (default: 120)
+            **kwargs: Additional parameters (reserved for future use)
         """
-        # Ollama serves OpenAI-compatible API at /v1
-        super().__init__(
-            model=model,
-            base_url=f"{base_url.rstrip('/')}/v1",
-            api_key="ollama",  # Ollama doesn't require a real API key
-            **kwargs,
+        super().__init__(model)
+        self._ollama_base_url = base_url.rstrip("/")
+        self._api_base_url = f"{self._ollama_base_url}/v1"
+        self._timeout_seconds = timeout_seconds
+
+    def generate(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        json_mode: bool = False,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response using Ollama's OpenAI-compatible API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
+            max_tokens: Maximum tokens to generate
+            json_mode: Enable JSON response format (model-dependent support)
+            **kwargs: Additional API parameters (max_new_tokens is mapped to max_tokens)
+
+        Returns:
+            LLMResponse object
+        """
+        import json
+        import urllib.error
+        import urllib.request
+
+        # Handle max_new_tokens -> max_tokens mapping for cross-provider compatibility
+        if "max_new_tokens" in kwargs:
+            if max_tokens is None:
+                max_tokens = kwargs.pop("max_new_tokens")
+            else:
+                kwargs.pop("max_new_tokens")  # max_tokens takes precedence
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        # Build API parameters
+        api_params: dict[str, Any] = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        if json_mode and self.supports_json_mode():
+            api_params["response_format"] = {"type": "json_object"}
+
+        # Merge additional parameters (excluding internal ones)
+        for key, value in kwargs.items():
+            if key not in api_params:
+                api_params[key] = value
+
+        # Make HTTP request to Ollama's OpenAI-compatible endpoint
+        url = f"{self._api_base_url}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer ollama",  # Ollama doesn't require a real API key
+        }
+
+        try:
+            request_data = json.dumps(api_params).encode("utf-8")
+            req = urllib.request.Request(url, data=request_data, headers=headers, method="POST")
+            with urllib.request.urlopen(req, timeout=self._timeout_seconds) as response:
+                response_data = json.loads(response.read().decode("utf-8"))
+        except urllib.error.HTTPError as e:
+            raise RuntimeError(
+                f"Ollama API error: {e.code} - {e.reason}"
+            ) from e
+        except urllib.error.URLError as e:
+            raise RuntimeError(
+                f"Failed to connect to Ollama at {self._ollama_base_url}. "
+                f"Ensure Ollama is running: {e}"
+            ) from e
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Failed to parse Ollama response: {e}") from e
+
+        # Parse response
+        choice = response_data.get("choices", [{}])[0]
+        usage = response_data.get("usage", {})
+        message = choice.get("message", {})
+
+        return LLMResponseBuilder.from_openai_format(
+            content=message.get("content", ""),
+            prompt_tokens=usage.get("prompt_tokens"),
+            completion_tokens=usage.get("completion_tokens"),
+            total_tokens=usage.get("total_tokens"),
+            model_name=response_data.get("model", self._model_name),
+            finish_reason=choice.get("finish_reason"),
         )
-        self._ollama_base_url = base_url
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @property
+    def ollama_base_url(self) -> str:
+        """Return the Ollama server base URL."""
+        return self._ollama_base_url
 
     @property
     def is_local(self) -> bool:
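
A usage sketch for the new generate() path, assuming the model is already pulled locally and that LLMResponse exposes the content field passed to LLMResponseBuilder; note how a HuggingFace-style max_new_tokens kwarg is mapped onto max_tokens:

    provider = OllamaProvider(model="qwen3:8b")

    response = provider.generate(
        system_prompt="You are a terse assistant.",
        user_prompt="Reply with a JSON object containing a 'status' key.",
        json_mode=True,       # honored only if supports_json_mode() is True
        max_new_tokens=64,    # popped from kwargs and used as max_tokens
    )
    print(response.content)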
@@ -445,6 +547,109 @@ def supports_vision(self) -> bool:
         model_lower = self._model_name.lower()
         return any(x in model_lower for x in ["llava", "bakllava", "moondream"])
 
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using Ollama's vision models.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature (0.0 = deterministic)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        import json
+        import urllib.error
+        import urllib.request
+
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use llava, bakllava, or moondream models."
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Vision message format with image_url
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        )
+
+        # Build API parameters
+        api_params: dict[str, Any] = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        # Merge additional parameters
+        for key, value in kwargs.items():
+            if key not in api_params:
+                api_params[key] = value
+
+        # Make HTTP request
+        url = f"{self._api_base_url}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer ollama",
+        }
+
+        try:
+            request_data = json.dumps(api_params).encode("utf-8")
+            req = urllib.request.Request(url, data=request_data, headers=headers, method="POST")
+            with urllib.request.urlopen(req, timeout=self._timeout_seconds) as response:
+                response_data = json.loads(response.read().decode("utf-8"))
+        except urllib.error.HTTPError as e:
+            raise RuntimeError(f"Ollama API error: {e.code} - {e.reason}") from e
+        except urllib.error.URLError as e:
+            raise RuntimeError(
+                f"Failed to connect to Ollama at {self._ollama_base_url}. "
+                f"Ensure Ollama is running: {e}"
+            ) from e
+
+        # Parse response
+        choice = response_data.get("choices", [{}])[0]
+        usage = response_data.get("usage", {})
+        message = choice.get("message", {})
+
+        return LLMResponseBuilder.from_openai_format(
+            content=message.get("content", ""),
+            prompt_tokens=usage.get("prompt_tokens"),
+            completion_tokens=usage.get("completion_tokens"),
+            total_tokens=usage.get("total_tokens"),
+            model_name=response_data.get("model", self._model_name),
+            finish_reason=choice.get("finish_reason"),
+        )
+
 
 class AnthropicProvider(LLMProvider):
     """

tests/test_ollama_provider.py

Lines changed: 11 additions & 9 deletions
@@ -2,34 +2,36 @@
 
 import pytest
 
-from predicate.llm_provider import OllamaProvider, OpenAIProvider
+from predicate.llm_provider import OllamaProvider, LLMProvider
 
 
 class TestOllamaProvider:
     """Test suite for OllamaProvider."""
 
-    def test_ollama_provider_is_subclass_of_openai(self):
-        """OllamaProvider should inherit from OpenAIProvider."""
-        assert issubclass(OllamaProvider, OpenAIProvider)
+    def test_ollama_provider_is_subclass_of_llm_provider(self):
+        """OllamaProvider should inherit from LLMProvider (not OpenAIProvider)."""
+        assert issubclass(OllamaProvider, LLMProvider)
 
     def test_ollama_provider_default_base_url(self):
         """OllamaProvider should use default localhost:11434 base URL."""
         provider = OllamaProvider(model="qwen3:8b")
-        # The internal client should have base_url set to /v1 endpoint
+        # The internal base URL should be set correctly
         assert provider._ollama_base_url == "http://localhost:11434"
+        assert provider.ollama_base_url == "http://localhost:11434"
 
     def test_ollama_provider_custom_base_url(self):
         """OllamaProvider should accept custom base URL."""
         provider = OllamaProvider(model="llama3:8b", base_url="http://192.168.1.100:11434")
         assert provider._ollama_base_url == "http://192.168.1.100:11434"
+        assert provider.ollama_base_url == "http://192.168.1.100:11434"
 
     def test_ollama_provider_strips_trailing_slash(self):
         """OllamaProvider should strip trailing slash from base URL."""
         provider = OllamaProvider(model="mistral:7b", base_url="http://localhost:11434/")
-        # The /v1 should be appended correctly without double slash
-        assert provider._ollama_base_url == "http://localhost:11434/"
-        # The actual OpenAI client base_url should be properly formed
-        # (trailing slash stripped before /v1 is appended)
+        # The trailing slash should be stripped
+        assert provider._ollama_base_url == "http://localhost:11434"
+        # The API base URL should be properly formed
+        assert provider._api_base_url == "http://localhost:11434/v1"
 
     def test_ollama_provider_is_local_property(self):
         """OllamaProvider.is_local should return True."""

traces/test-run.jsonl

Lines changed: 5 additions & 0 deletions
@@ -13,3 +13,8 @@
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508045}
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508046}
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508134}
+{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119216}
+{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119218}
+{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119219}
+{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119221}
+{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119306}
