2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/services/configurators/run.py
@@ -354,7 +354,7 @@ def interpolate_env(self, conf: RunConfigurationT):
             password=interpolator.interpolate_or_error(conf.registry_auth.password),
         )
         if isinstance(conf, ServiceConfiguration):
-            for probe in conf.probes:
+            for probe in conf.probes or []:
                 for header in probe.headers:
                     header.value = interpolator.interpolate_or_error(header.value)
                 if probe.url:
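For context, the loop above resolves interpolation placeholders in probe headers and URLs. A minimal sketch of a probe that relies on it, assuming dstack's ${{ env.VAR }} interpolation syntax applies to probe fields the way it does to registry_auth (API_TOKEN is a made-up name):

# Hypothetical probe config whose Authorization header is resolved by
# interpolate_env; the ${{ env.API_TOKEN }} syntax is assumed from dstack's
# interpolation elsewhere.
probe = {
    "type": "http",
    "url": "/health",
    "headers": [
        {"name": "Authorization", "value": "Bearer ${{ env.API_TOKEN }}"},
    ],
}

With conf.probes now Optional, the `or []` guard keeps this loop from iterating over None when probes is omitted.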
39 changes: 36 additions & 3 deletions src/dstack/_internal/core/models/configurations.py
@@ -56,6 +56,8 @@
 DEFAULT_PROBE_METHOD = "get"
 MAX_PROBE_URL_LEN = 2048
 DEFAULT_REPLICA_GROUP_NAME = "0"
+DEFAULT_MODEL_PROBE_TIMEOUT = 30
+DEFAULT_MODEL_PROBE_URL = "/v1/chat/completions"


 class RunConfigurationType(str, Enum):
@@ -851,9 +853,9 @@ class ServiceConfigurationParams(CoreModel):
     ] = None
     rate_limits: Annotated[list[RateLimit], Field(description="Rate limiting rules")] = []
     probes: Annotated[
-        list[ProbeConfig],
+        Optional[list[ProbeConfig]],
         Field(description="List of probes used to determine job health"),
Reviewer (Collaborator): Consider documenting the default probe either here or in model. And optionally in concepts/services.md.

-    ] = []
+    ] = None  # None = omitted (may get default when model is set); [] = explicit empty
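To make the new semantics concrete, a sketch of the three states this field can take, mirroring the tests added at the bottom of this diff (parse_run_configuration imported as in those tests; the default probe comes from the validator added below):

from dstack._internal.core.models.configurations import parse_run_configuration

base = {
    "type": "service",
    "commands": ["python3 -m http.server"],
    "port": 8000,
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
}

# probes omitted -> the default model probe is injected
assert len(parse_run_configuration(base).probes) == 1

# probes: [] -> explicitly no probes; no default is added
assert parse_run_configuration({**base, "probes": []}).probes == []

# explicit probes are kept as-is
parsed = parse_run_configuration({**base, "probes": [{"type": "http", "url": "/health"}]})
assert parsed.probes[0].url == "/health"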

     replicas: Annotated[
         Optional[Union[List[ReplicaGroup], Range[int]]],
@@ -895,7 +897,9 @@ def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
         return v

     @validator("probes")
-    def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
+    def validate_probes(cls, v: Optional[list[ProbeConfig]]) -> Optional[list[ProbeConfig]]:
+        if v is None:
+            return v
         if has_duplicates(v):
             # Using a custom validator instead of Field(unique_items=True) to avoid Pydantic bug:
             # https://github.com/pydantic/pydantic/issues/3765
@@ -932,6 +936,35 @@ def validate_replicas(
             )
         return v

+    @root_validator()
+    def set_default_probes_for_model(cls, values):
Reviewer (Collaborator), on lines +939 to +940: This sets the default on the client side, and we prefer server-side defaults. Consider handling the default probe in JobConfigurator._probes instead, so that the configuration always holds the initial value (None) and only the job spec holds the default probe.

+        model = values.get("model")
+        probes = values.get("probes")
+        if model is not None and probes is None:
Reviewer (Collaborator), nit on `model is not None`: this check also passes for models that are declared using the full syntax that are not OpenAI-compatible or not chat models (once we support those). Such models don't necessarily provide /v1/chat/completions, so the probe will not work. A better check would be isinstance(configuration.model, OpenAIChatModel).

+            body = orjson.dumps(
+                {
+                    "model": model.name,
+                    "messages": [{"role": "user", "content": "hi"}],
+                    "max_tokens": 1,
+                }
+            ).decode("utf-8")
+            values["probes"] = [
+                ProbeConfig(
+                    type="http",
+                    method="post",
+                    url=DEFAULT_MODEL_PROBE_URL,
+                    headers=[
+                        HTTPHeaderSpec(name="Content-Type", value="application/json"),
+                    ],
+                    body=body,
+                    timeout=DEFAULT_MODEL_PROBE_TIMEOUT,
+                )
+            ]
+        elif probes is None:
+            # Probes omitted and model not set: normalize to empty list for downstream.
+            values["probes"] = []
+        return values
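A minimal sketch of the reviewers' suggestion above: leave probes as None in the configuration and derive the default probe server-side in JobConfigurator._probes, gated on the model type. ServiceConfiguration, ProbeSpec, and _probe_config_to_spec are used as elsewhere in this diff; OpenAIChatModel is taken from the review note, and default_model_probe_config is a hypothetical helper that would build the ProbeConfig constructed above:

def _probes(self) -> list[ProbeSpec]:
    conf = self.run_spec.configuration
    if not isinstance(conf, ServiceConfiguration):
        return []
    if conf.probes is not None:
        # Explicit probes, including an explicit empty list, are taken as-is.
        return list(map(_probe_config_to_spec, conf.probes))
    if isinstance(conf.model, OpenAIChatModel):
        # Probes omitted and an OpenAI-compatible chat model is set:
        # fall back to the default chat-completions probe.
        return [_probe_config_to_spec(default_model_probe_config(conf.model))]
    return []

This keeps the parsed configuration equal to what the user wrote, so re-serializing it never bakes the default in.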

     @root_validator()
     def validate_scaling(cls, values):
         scaling = values.get("scaling")
@@ -394,7 +394,7 @@ def _service_port(self) -> Optional[int]:

     def _probes(self) -> list[ProbeSpec]:
         if isinstance(self.run_spec.configuration, ServiceConfiguration):
-            return list(map(_probe_config_to_spec, self.run_spec.configuration.probes))
+            return list(map(_probe_config_to_spec, self.run_spec.configuration.probes or []))
         return []


4 changes: 2 additions & 2 deletions src/dstack/_internal/server/services/runs/spec.py
@@ -94,13 +94,13 @@ def validate_run_spec_and_set_defaults(
             raise ServerClientError(
                 "Scheduled services with autoscaling to zero are not supported"
             )
-        if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB:
+        if len(run_spec.configuration.probes or []) > settings.MAX_PROBES_PER_JOB:
             raise ServerClientError(
                 f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes"
             )
         if any(
             p.timeout is not None and p.timeout > settings.MAX_PROBE_TIMEOUT
-            for p in run_spec.configuration.probes
+            for p in (run_spec.configuration.probes or [])
         ):
             raise ServerClientError(
                 f"Probe timeout cannot be longer than {settings.MAX_PROBE_TIMEOUT}s"
45 changes: 45 additions & 0 deletions src/tests/_internal/core/models/test_configurations.py
@@ -5,6 +5,8 @@
 from dstack._internal.core.errors import ConfigurationError
 from dstack._internal.core.models.common import RegistryAuth
 from dstack._internal.core.models.configurations import (
+    DEFAULT_MODEL_PROBE_TIMEOUT,
+    DEFAULT_MODEL_PROBE_URL,
     DevEnvironmentConfigurationParams,
     RepoSpec,
     parse_run_configuration,
@@ -13,6 +15,49 @@


 class TestParseConfiguration:
+    def test_service_model_sets_default_probes_when_probes_omitted(self):
+        conf = {
+            "type": "service",
+            "commands": ["python3 -m http.server"],
+            "port": 8000,
+            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+        }
+        parsed = parse_run_configuration(conf)
+        assert len(parsed.probes) == 1
+        probe = parsed.probes[0]
+        assert probe.type == "http"
+        assert probe.method == "post"
+        assert probe.url == DEFAULT_MODEL_PROBE_URL
+        assert probe.timeout == DEFAULT_MODEL_PROBE_TIMEOUT
+        assert len(probe.headers) == 1
+        assert probe.headers[0].name == "Content-Type"
+        assert probe.headers[0].value == "application/json"
+        assert "meta-llama/Meta-Llama-3.1-8B-Instruct" in (probe.body or "")
+        assert "max_tokens" in (probe.body or "")
+
+    def test_service_model_does_not_override_explicit_probes(self):
+        conf = {
+            "type": "service",
+            "commands": ["python3 -m http.server"],
+            "port": 8000,
+            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "probes": [{"type": "http", "url": "/health"}],
+        }
+        parsed = parse_run_configuration(conf)
+        assert len(parsed.probes) == 1
+        assert parsed.probes[0].url == "/health"
+
+    def test_service_model_explicit_empty_probes_no_default(self):
+        conf = {
+            "type": "service",
+            "commands": ["python3 -m http.server"],
+            "port": 8000,
+            "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+            "probes": [],
+        }
+        parsed = parse_run_configuration(conf)
+        assert len(parsed.probes) == 0
+
     def test_services_replicas_and_scaling(self):
         def test_conf(replicas: Any, scaling: Optional[Any] = None):
             conf = {