diff --git a/server/internal/database/rag_service_config.go b/server/internal/database/rag_service_config.go
index 5c86eee9..a51feca7 100644
--- a/server/internal/database/rag_service_config.go
+++ b/server/internal/database/rag_service_config.go
@@ -4,11 +4,23 @@ import (
 	"bytes"
 	"encoding/json"
 	"fmt"
+	"regexp"
 	"slices"
 	"sort"
 	"strings"
 )
 
+// ragPipelineNamePatternText is the allowlist pattern for RAG pipeline names.
+// It is kept as a const so that the compiled regexp and the error message both
+// reference the same literal and cannot drift apart.
+const ragPipelineNamePatternText = `^[a-z0-9_][a-z0-9_-]*$`
+
+// ragPipelineNamePattern restricts pipeline names to lowercase ASCII letters,
+// digits, hyphens, and underscores. The first character must not be a hyphen
+// so that names are safe as filename components and cannot be misinterpreted
+// as CLI flags if ever passed to a command.
+var ragPipelineNamePattern = regexp.MustCompile(ragPipelineNamePatternText)
+
 // RAGPipelineLLMConfig represents LLM configuration for an embedding or RAG step.
 type RAGPipelineLLMConfig struct {
 	Provider string `json:"provider"`
@@ -126,9 +138,11 @@ func validateRAGPipeline(p RAGPipeline, i int, seenNames map[string]bool) []erro
 	var errs []error
 	prefix := fmt.Sprintf("pipelines[%d]", i)
 
-	// name (required, unique)
+	// name (required, allowlist, unique)
 	if p.Name == "" {
 		errs = append(errs, fmt.Errorf("%s.name is required", prefix))
+	} else if !ragPipelineNamePattern.MatchString(p.Name) {
+		errs = append(errs, fmt.Errorf("%s.name %q is invalid: must match %s", prefix, p.Name, ragPipelineNamePatternText))
 	} else if seenNames[p.Name] {
 		errs = append(errs, fmt.Errorf("pipelines contains duplicate name %q", p.Name))
 	} else {
diff --git a/server/internal/database/rag_service_config_test.go b/server/internal/database/rag_service_config_test.go
index 953f41c7..b9f3f7ca 100644
--- a/server/internal/database/rag_service_config_test.go
+++ b/server/internal/database/rag_service_config_test.go
@@ -384,6 +384,51 @@ func TestParseRAGServiceConfig_MissingRAGLLM(t *testing.T) {
 	assert.Contains(t, errs[0].Error(), "rag_llm.provider")
 }
 
+// TestParseRAGServiceConfig_PipelineNameAllowlist checks that pipeline names
+// matching the allowlist pattern parse without errors and that every other
+// non-empty name is rejected with a "must match" error.
+func TestParseRAGServiceConfig_PipelineNameAllowlist(t *testing.T) {
+	validNames := []string{
+		"default",
+		"my-pipeline",
+		"my_pipeline",
+		"pipeline-1",
+		"a",
+		"abc123",
+		"a-b_c-1",
+	}
+	for _, name := range validNames {
+		t.Run("valid/"+name, func(t *testing.T) {
+			config := minimalRAGConfig()
+			config["pipelines"].([]any)[0].(map[string]any)["name"] = name
+			_, errs := database.ParseRAGServiceConfig(config, false)
+			assert.Empty(t, errs, "name %q should be valid", name)
+		})
+	}
+
+	// The empty name is deliberately absent: it is rejected by the separate
+	// "name is required" check and never reaches the pattern match.
+	invalidNames := []string{
+		"My Pipeline",   // uppercase + space
+		"pipeline name", // space
+		"pipeline/name", // slash
+		"../etc/passwd", // path traversal
+		"UPPER",         // uppercase
+		"pipe🔥line",     // unicode emoji
+		"pipeline.name", // dot
+		"-pipeline",     // leading hyphen (could be misread as a CLI flag)
+	}
+	for _, name := range invalidNames {
+		t.Run("invalid/"+name, func(t *testing.T) {
+			config := minimalRAGConfig()
+			config["pipelines"].([]any)[0].(map[string]any)["name"] = name
+			_, errs := database.ParseRAGServiceConfig(config, false)
+			require.NotEmpty(t, errs, "name %q should be invalid", name)
+			assert.Contains(t, errs[0].Error(), "must match ^[a-z0-9_][a-z0-9_-]*$")
+		})
+	}
+}
+
 func TestParseRAGServiceConfig_MultiplePipelines(t *testing.T) {
 	config := map[string]any{
 		"pipelines": []any{