From 76a26ff06c5ee8c90d1d9e78ca695a3399841c34 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Tue, 24 Mar 2026 18:03:59 -0400 Subject: [PATCH 1/2] Add x-cog-accept MIME type annotation for Path/File input fields Add an 'accept' parameter to Input() that specifies allowed MIME types or file extensions for Path/File inputs. The Go static schema generator extracts this and emits it as 'x-cog-accept' in the OpenAPI schema, giving schema consumers (UIs, validators, API clients) visibility into what file types an input expects. Usage: Input(accept="image/*"), Input(accept="audio/wav,audio/mp3"), or Input(accept=".safetensors,.bin"). Using accept on non-Path/File types is a hard build error. --- .../tests/accept_mime_type.txtar | 55 +++++++++++ .../tests/accept_mime_type_error.txtar | 23 +++++ pkg/schema/errors.go | 1 + pkg/schema/openapi.go | 5 + pkg/schema/openapi_test.go | 74 +++++++++++++++ pkg/schema/python/parser.go | 67 ++++++++------ pkg/schema/python/parser_test.go | 91 +++++++++++++++++++ pkg/schema/types.go | 1 + python/cog/input.py | 6 ++ 9 files changed, 295 insertions(+), 28 deletions(-) create mode 100644 integration-tests/tests/accept_mime_type.txtar create mode 100644 integration-tests/tests/accept_mime_type_error.txtar diff --git a/integration-tests/tests/accept_mime_type.txtar b/integration-tests/tests/accept_mime_type.txtar new file mode 100644 index 0000000000..ded967ff04 --- /dev/null +++ b/integration-tests/tests/accept_mime_type.txtar @@ -0,0 +1,55 @@ +# Test that the accept parameter on Path/File inputs produces the +# x-cog-accept annotation in the generated OpenAPI schema. +# +# Verifies: +# - accept="image/*" on a Path input emits x-cog-accept in the schema +# - accept with multiple MIME types works +# - accept with file extensions works +# - Fields without accept do not have x-cog-accept +# - Prediction still works end-to-end + +cog build -t $TEST_IMAGE + +# Extract the schema from the image label +exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}' + +# x-cog-accept annotations are present +stdout '"x-cog-accept":"image/\*"' +stdout '"x-cog-accept":"audio/wav,audio/mp3"' +stdout '"x-cog-accept":".safetensors,.bin"' + +# The prompt field (str) should NOT have x-cog-accept +# (we check the schema has prompt but confirm no extra x-cog-accept entries) +stdout '"prompt":' + +# Path fields still have uri format +stdout '"format":"uri"' + +# Prediction works end-to-end +cog predict $TEST_IMAGE -i prompt=hello -i image=@test.png +stdout 'hello-png' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Optional + +from cog import BasePredictor, Input, Path + + +class Predictor(BasePredictor): + def predict( + self, + prompt: str = Input(description="Text prompt", default="test"), + image: Path = Input(description="Input image", accept="image/*"), + audio: Optional[Path] = Input(description="Audio clip", accept="audio/wav,audio/mp3", default=None), + weights: Optional[Path] = Input(description="Model weights", accept=".safetensors,.bin", default=None), + ) -> str: + ext = str(image).split(".")[-1] + return f"{prompt}-{ext}" + +-- test.png -- +fake image content diff --git a/integration-tests/tests/accept_mime_type_error.txtar b/integration-tests/tests/accept_mime_type_error.txtar new file mode 100644 index 0000000000..c5da86a0a3 --- /dev/null +++ b/integration-tests/tests/accept_mime_type_error.txtar @@ -0,0 +1,23 @@ +# Test that using accept on a non-Path/File input type causes a build error. +# +# The accept parameter is only valid on Path or File inputs. Using it on +# str, int, float, etc. should produce a clear error at build time. + +! cog build -t $TEST_IMAGE +stderr 'accept is only valid on Path or File inputs' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor, Input + + +class Predictor(BasePredictor): + def predict( + self, + name: str = Input(description="User name", accept="text/plain"), + ) -> str: + return f"hello {name}" diff --git a/pkg/schema/errors.go b/pkg/schema/errors.go index c046df4209..81a8cb3402 100644 --- a/pkg/schema/errors.go +++ b/pkg/schema/errors.go @@ -28,6 +28,7 @@ const ( ErrChoicesNotResolvable ErrDefaultNotResolvable ErrUnresolvableType + ErrAcceptOnNonFileType ErrOther ) diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index d0cf252ea0..6edb43db1b 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -365,6 +365,11 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) { prop.Set("deprecated", true) } + // MIME type constraint for Path/File inputs + if field.Accept != nil { + prop.Set("x-cog-accept", *field.Accept) + } + properties.Set(name, prop) }) diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index af85736748..d4cd44123e 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -682,6 +682,80 @@ func TestMultipleInputTypes(t *testing.T) { assert.NotContains(t, required, "secret_key") } +// --------------------------------------------------------------------------- +// Tests: Accept (MIME type) annotation +// --------------------------------------------------------------------------- + +func TestAcceptAnnotation(t *testing.T) { + accept := "image/*" + inputs := NewOrderedMap[string, InputField]() + inputs.Set("image", InputField{ + Name: "image", + Order: 0, + FieldType: FieldType{Primitive: TypePath, Repetition: Required}, + Accept: &accept, + }) + + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) + + imageField := props["image"].(map[string]any) + assert.Equal(t, "string", imageField["type"]) + assert.Equal(t, "uri", imageField["format"]) + assert.Equal(t, "image/*", imageField["x-cog-accept"]) +} + +func TestAcceptAnnotationMultipleMimeTypes(t *testing.T) { + accept := "audio/wav,audio/mp3,audio/flac" + inputs := NewOrderedMap[string, InputField]() + inputs.Set("audio", InputField{ + Name: "audio", + Order: 0, + FieldType: FieldType{Primitive: TypePath, Repetition: Required}, + Accept: &accept, + }) + + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) + + audioField := props["audio"].(map[string]any) + assert.Equal(t, "audio/wav,audio/mp3,audio/flac", audioField["x-cog-accept"]) +} + +func TestAcceptAnnotationNotPresentWhenNil(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + inputs.Set("image", InputField{ + Name: "image", + Order: 0, + FieldType: FieldType{Primitive: TypePath, Repetition: Required}, + }) + + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaPrim(TypeString), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any) + + imageField := props["image"].(map[string]any) + _, hasAccept := imageField["x-cog-accept"] + assert.False(t, hasAccept, "x-cog-accept should not be present when Accept is nil") +} + // --------------------------------------------------------------------------- // Tests: Edge cases // --------------------------------------------------------------------------- diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index a2e79c40bc..be9a981f94 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -731,6 +731,7 @@ type inputCallInfo struct { Regex *string Choices []schema.DefaultValue Deprecated *bool + Accept *string } type inputMethodInfo struct { @@ -1148,38 +1149,20 @@ func parseTypedDefaultParameter( if err != nil { return schema.InputField{}, err } - return schema.InputField{ - Name: name, - Order: order, - FieldType: fieldType, - Default: info.Default, - Description: info.Description, - GE: info.GE, - LE: info.LE, - MinLength: info.MinLength, - MaxLength: info.MaxLength, - Regex: info.Regex, - Choices: info.Choices, - Deprecated: info.Deprecated, - }, nil + field, err := inputCallInfoToField(name, order, fieldType, info) + if err != nil { + return schema.InputField{}, err + } + return field, nil } // 2. Reference to Input() via class attribute or static method if info, ok := resolveInputReference(valNode, source, registry); ok { - return schema.InputField{ - Name: name, - Order: order, - FieldType: fieldType, - Default: info.Default, - Description: info.Description, - GE: info.GE, - LE: info.LE, - MinLength: info.MinLength, - MaxLength: info.MaxLength, - Regex: info.Regex, - Choices: info.Choices, - Deprecated: info.Deprecated, - }, nil + field, err := inputCallInfoToField(name, order, fieldType, info) + if err != nil { + return schema.InputField{}, err + } + return field, nil } // 3. Plain default — must be statically resolvable @@ -1339,6 +1322,30 @@ func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext return false } +// inputCallInfoToField converts parsed Input() kwargs into an InputField, +// validating that accept is only used on Path/File types. +func inputCallInfoToField(name string, order int, fieldType schema.FieldType, info inputCallInfo) (schema.InputField, error) { + if info.Accept != nil && fieldType.Primitive != schema.TypePath && fieldType.Primitive != schema.TypeFile { + return schema.InputField{}, schema.NewError(schema.ErrAcceptOnNonFileType, + fmt.Sprintf("accept is only valid on Path or File inputs (parameter '%s')", name)) + } + return schema.InputField{ + Name: name, + Order: order, + FieldType: fieldType, + Default: info.Default, + Description: info.Description, + GE: info.GE, + LE: info.LE, + MinLength: info.MinLength, + MaxLength: info.MaxLength, + Regex: info.Regex, + Choices: info.Choices, + Deprecated: info.Deprecated, + Accept: info.Accept, + }, nil +} + func parseInputCall(node *sitter.Node, source []byte, paramName string, scope moduleScope) (inputCallInfo, error) { var info inputCallInfo @@ -1406,6 +1413,10 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo if b, ok := parseBoolLiteral(valNode, source); ok { info.Deprecated = &b } + case "accept": + if s, ok := parseStringLiteral(valNode, source); ok { + info.Accept = &s + } } } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 3458b643c6..4f11c0ba36 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -922,6 +922,97 @@ class Predictor(BasePredictor): require.True(t, *old.Deprecated) } +// --------------------------------------------------------------------------- +// Accept MIME type +// --------------------------------------------------------------------------- + +func TestAcceptMimeType(t *testing.T) { + source := ` +from cog import BasePredictor, Input, Path + +class Predictor(BasePredictor): + def predict(self, image: Path = Input(description="An image", accept="image/*")) -> str: + pass +` + info := parse(t, source, "Predictor") + image, ok := info.Inputs.Get("image") + require.True(t, ok) + require.NotNil(t, image.Accept) + require.Equal(t, "image/*", *image.Accept) +} + +func TestAcceptMultipleMimeTypes(t *testing.T) { + source := ` +from cog import BasePredictor, Input, Path + +class Predictor(BasePredictor): + def predict(self, audio: Path = Input(accept="audio/wav,audio/mp3")) -> str: + pass +` + info := parse(t, source, "Predictor") + audio, ok := info.Inputs.Get("audio") + require.True(t, ok) + require.NotNil(t, audio.Accept) + require.Equal(t, "audio/wav,audio/mp3", *audio.Accept) +} + +func TestAcceptFileExtensions(t *testing.T) { + source := ` +from cog import BasePredictor, Input, Path + +class Predictor(BasePredictor): + def predict(self, weights: Path = Input(accept=".safetensors,.bin")) -> str: + pass +` + info := parse(t, source, "Predictor") + weights, ok := info.Inputs.Get("weights") + require.True(t, ok) + require.NotNil(t, weights.Accept) + require.Equal(t, ".safetensors,.bin", *weights.Accept) +} + +func TestAcceptOnFileType(t *testing.T) { + source := ` +from cog import BasePredictor, Input, File + +class Predictor(BasePredictor): + def predict(self, f: File = Input(accept="image/png")) -> str: + pass +` + info := parse(t, source, "Predictor") + f, ok := info.Inputs.Get("f") + require.True(t, ok) + require.NotNil(t, f.Accept) + require.Equal(t, "image/png", *f.Accept) +} + +func TestAcceptOnNonFileTypeErrors(t *testing.T) { + source := ` +from cog import BasePredictor, Input + +class Predictor(BasePredictor): + def predict(self, name: str = Input(accept="image/*")) -> str: + pass +` + se := parseErr(t, source, "Predictor", schema.ModePredict) + require.Equal(t, schema.ErrAcceptOnNonFileType, se.Kind) + require.Contains(t, se.Error(), "name") +} + +func TestAcceptNotSetWhenOmitted(t *testing.T) { + source := ` +from cog import BasePredictor, Input, Path + +class Predictor(BasePredictor): + def predict(self, image: Path = Input(description="An image")) -> str: + pass +` + info := parse(t, source, "Predictor") + image, ok := info.Inputs.Get("image") + require.True(t, ok) + require.Nil(t, image.Accept) +} + // --------------------------------------------------------------------------- // File type (deprecated alias for Path) // --------------------------------------------------------------------------- diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 46a7007d29..c4b0c6ed4d 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -187,6 +187,7 @@ type InputField struct { Regex *string Choices []DefaultValue Deprecated *bool + Accept *string // MIME types / file extensions for Path/File inputs (e.g. "image/*") } // IsRequired returns true if this field is required in the schema. diff --git a/python/cog/input.py b/python/cog/input.py index 9a9063412b..02531b1a4d 100644 --- a/python/cog/input.py +++ b/python/cog/input.py @@ -27,6 +27,7 @@ class FieldInfo: regex: Optional[str] = None choices: Optional[List[Union[str, int]]] = None deprecated: Optional[bool] = None + accept: Optional[str] = None def Input( @@ -41,6 +42,7 @@ def Input( regex: Optional[str] = None, choices: Optional[List[Union[str, int]]] = None, deprecated: Optional[bool] = None, + accept: Optional[str] = None, ) -> Any: """ Create an input field specification for a predictor parameter. @@ -70,6 +72,9 @@ def predict( regex: Regular expression pattern for string inputs. choices: List of allowed values. deprecated: Whether the input is deprecated. + accept: Allowed MIME types or file extensions for Path/File inputs, + using the same format as the HTML accept attribute (e.g. + ``"image/*"``, ``"image/png,image/jpeg"``, ``".safetensors,.bin"``). Returns: A FieldInfo instance containing the field metadata. @@ -92,4 +97,5 @@ def predict( regex=regex, choices=choices, deprecated=deprecated, + accept=accept, ) From e4311a593bb979f0b3cf40da6af05e4aa76d7411 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Tue, 24 Mar 2026 18:11:18 -0400 Subject: [PATCH 2/2] Set COG_STATIC_SCHEMA=1 in accept integration tests --- integration-tests/tests/accept_mime_type.txtar | 2 ++ integration-tests/tests/accept_mime_type_error.txtar | 2 ++ 2 files changed, 4 insertions(+) diff --git a/integration-tests/tests/accept_mime_type.txtar b/integration-tests/tests/accept_mime_type.txtar index ded967ff04..7ebc86f3f2 100644 --- a/integration-tests/tests/accept_mime_type.txtar +++ b/integration-tests/tests/accept_mime_type.txtar @@ -8,6 +8,8 @@ # - Fields without accept do not have x-cog-accept # - Prediction still works end-to-end +env COG_STATIC_SCHEMA=1 + cog build -t $TEST_IMAGE # Extract the schema from the image label diff --git a/integration-tests/tests/accept_mime_type_error.txtar b/integration-tests/tests/accept_mime_type_error.txtar index c5da86a0a3..376257f461 100644 --- a/integration-tests/tests/accept_mime_type_error.txtar +++ b/integration-tests/tests/accept_mime_type_error.txtar @@ -3,6 +3,8 @@ # The accept parameter is only valid on Path or File inputs. Using it on # str, int, float, etc. should produce a clear error at build time. +env COG_STATIC_SCHEMA=1 + ! cog build -t $TEST_IMAGE stderr 'accept is only valid on Path or File inputs'