diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go index ffbf0ffb163..100cdf5d60c 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" "time" @@ -475,6 +476,18 @@ func (a *InitAction) Run(ctx context.Context) error { return fmt.Errorf("configuring model choice: %w", err) } + // Prompt for manifest parameters (e.g. tool credentials) after project selection + agentManifest, err = registry_api.ProcessManifestParameters( + ctx, agentManifest, a.azdClient, a.flags.NoPrompt, + ) + if err != nil { + return fmt.Errorf("failed to process manifest parameters: %w", err) + } + + // Inject toolbox MCP endpoint env vars into hosted agent definitions + // so agent.yaml is self-documenting about what env vars will be set. + injectToolboxEnvVarsIntoDefinition(agentManifest) + // Write the final agent.yaml to disk (after deployment names have been injected) if err := writeAgentDefinitionFile(targetDir, agentManifest); err != nil { return fmt.Errorf("writing agent definition: %w", err) @@ -1078,11 +1091,6 @@ func (a *InitAction) downloadAgentYaml( } } - agentManifest, err = registry_api.ProcessManifestParameters(ctx, agentManifest, a.azdClient, a.flags.NoPrompt) - if err != nil { - return nil, "", fmt.Errorf("failed to process manifest parameters: %w", err) - } - _, isPromptAgent := agentManifest.Template.(agent_yaml.PromptAgent) if isPromptAgent { agentManifest, err = agent_yaml.ProcessPromptAgentToolsConnections(ctx, agentManifest, a.azdClient) @@ -1237,6 +1245,33 @@ func (a *InitAction) addToProject(ctx context.Context, targetDir string, agentMa agentConfig.Deployments = a.deploymentDetails agentConfig.Resources = resourceDetails + // Process toolbox resources from the manifest + toolboxes, toolConnections, credEnvVars, err := extractToolboxAndConnectionConfigs(agentManifest) + if err != nil { + return err + } + agentConfig.Toolboxes = toolboxes + agentConfig.ToolConnections = toolConnections + + // Persist credential values as azd environment variables so they are + // resolved at provision/deploy time instead of stored in azure.yaml. + for envKey, envVal := range credEnvVars { + if _, setErr := a.azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ + EnvName: a.environment.Name, + Key: envKey, + Value: envVal, + }); setErr != nil { + return fmt.Errorf("storing credential env var %s: %w", envKey, setErr) + } + } + + // Process connection resources from the manifest + connections, err := extractConnectionConfigs(agentManifest) + if err != nil { + return err + } + agentConfig.Connections = connections + // Detect startup command from the project source directory startupCmd, err := resolveStartupCommandForInit(ctx, a.azdClient, a.projectConfig.Path, targetDir, a.flags.NoPrompt) if err != nil { @@ -1645,3 +1680,222 @@ func downloadDirectoryContentsWithoutGhCli( return nil } + +// extractToolboxAndConnectionConfigs extracts toolbox resource definitions from the agent manifest +// and converts them into project.Toolbox config entries and project.ToolConnection entries. +// Tools with a target/authType also produce connection entries for Bicep provisioning. +// Built-in tools (bing_grounding, azure_ai_search, etc.) produce toolbox tools but no connections. +func extractToolboxAndConnectionConfigs( + manifest *agent_yaml.AgentManifest, +) ([]project.Toolbox, []project.ToolConnection, map[string]string, error) { + if manifest == nil || manifest.Resources == nil { + return nil, nil, nil, nil + } + + var toolboxes []project.Toolbox + var connections []project.ToolConnection + // credentialEnvVars maps generated env var names to their raw values so + // the caller can persist them in the azd environment. + credentialEnvVars := map[string]string{} + + for _, resource := range manifest.Resources { + tbResource, ok := resource.(agent_yaml.ToolboxResource) + if !ok { + continue + } + + description := tbResource.Description + + if len(tbResource.Tools) == 0 { + return nil, nil, nil, fmt.Errorf( + "toolbox resource '%s' is missing required 'tools'", + tbResource.Name, + ) + } + + var tools []map[string]any + for _, rawTool := range tbResource.Tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + return nil, nil, nil, fmt.Errorf( + "toolbox resource '%s' has invalid tool entry: expected object", + tbResource.Name, + ) + } + + // Manifest and API both use "type" for tool kind + toolType, _ := toolMap["type"].(string) + + target, _ := toolMap["target"].(string) + if target == "" { + // No target — either a built-in tool or a pre-configured tool + // that already has project_connection_id. Pass through as-is. + result := make(map[string]any, len(toolMap)) + for k, v := range toolMap { + result[k] = v + } + tools = append(tools, result) + continue + } + + // External tools with target/authType need a connection + toolName, _ := toolMap["name"].(string) + authType, _ := toolMap["authType"].(string) + credentials, _ := toolMap["credentials"].(map[string]any) + + connName := toolName + if connName == "" { + connName = tbResource.Name + "-" + toolType + } + + conn := project.ToolConnection{ + Name: connName, + Category: "RemoteTool", + Target: target, + AuthType: authType, + } + + // Extract credentials, storing raw values as env vars and + // replacing them with ${VAR} references in the config. + if len(credentials) > 0 { + creds := make(map[string]any, len(credentials)) + for k, v := range credentials { + envVar := credentialEnvVarName(connName, k) + credentialEnvVars[envVar] = fmt.Sprintf("%v", v) + creds[k] = fmt.Sprintf("${%s}", envVar) + } + + // CustomKeys ARM type requires credentials nested under "keys" + if authType == "CustomKeys" { + conn.Credentials = map[string]any{"keys": creds} + } else { + conn.Credentials = creds + } + } + + connections = append(connections, conn) + + // Toolbox tool entry is minimal — deploy enriches from connection + tool := map[string]any{ + "type": toolType, + "project_connection_id": connName, + } + tools = append(tools, tool) + } + + toolboxes = append(toolboxes, project.Toolbox{ + Name: tbResource.Name, + Description: description, + Tools: tools, + }) + } + + return toolboxes, connections, credentialEnvVars, nil +} + +// credentialEnvVarName builds a deterministic env var name for a connection +// credential key, e.g. ("github-copilot", "clientSecret") → "TOOL_GITHUB_COPILOT_CLIENTSECRET". +// All non-alphanumeric characters are replaced with underscores and consecutive +// underscores are collapsed to produce a valid [A-Z0-9_]+ environment variable name. +var nonAlphanumRe = regexp.MustCompile(`[^A-Z0-9]+`) + +func credentialEnvVarName(connName, key string) string { + s := "TOOL_" + strings.ToUpper(connName) + "_" + strings.ToUpper(key) + return nonAlphanumRe.ReplaceAllString(s, "_") +} + +// injectToolboxEnvVarsIntoDefinition adds TOOLBOX_{NAME}_MCP_ENDPOINT entries +// to the environment_variables section of a hosted agent definition for each toolbox +// resource in the manifest. Entries already present in the definition are not overwritten. +func injectToolboxEnvVarsIntoDefinition(manifest *agent_yaml.AgentManifest) { + if manifest == nil || manifest.Resources == nil { + return + } + + containerAgent, ok := manifest.Template.(agent_yaml.ContainerAgent) + if !ok { + return + } + + // Collect toolbox resource names + var toolboxNames []string + for _, resource := range manifest.Resources { + if tbResource, ok := resource.(agent_yaml.ToolboxResource); ok { + toolboxNames = append(toolboxNames, tbResource.Name) + } + } + if len(toolboxNames) == 0 { + return + } + + if containerAgent.EnvironmentVariables == nil { + envVars := []agent_yaml.EnvironmentVariable{} + containerAgent.EnvironmentVariables = &envVars + } + + existingNames := make(map[string]bool, len(*containerAgent.EnvironmentVariables)) + for _, ev := range *containerAgent.EnvironmentVariables { + existingNames[ev.Name] = true + } + + for _, tbName := range toolboxNames { + envKey := toolboxMCPEndpointEnvKey(tbName) + if !existingNames[envKey] { + *containerAgent.EnvironmentVariables = append( + *containerAgent.EnvironmentVariables, + agent_yaml.EnvironmentVariable{ + Name: envKey, + Value: fmt.Sprintf("${%s}", envKey), + }, + ) + } + } + + manifest.Template = containerAgent +} + +// extractConnectionConfigs extracts connection resource definitions from the agent manifest +// and converts them into project.Connection config entries. +func extractConnectionConfigs(manifest *agent_yaml.AgentManifest) ([]project.Connection, error) { + if manifest == nil || manifest.Resources == nil { + return nil, nil + } + + var connections []project.Connection + + for _, resource := range manifest.Resources { + connResource, ok := resource.(agent_yaml.ConnectionResource) + if !ok { + continue + } + + conn := project.Connection{ + Name: connResource.Name, + Category: string(connResource.Category), + Target: connResource.Target, + AuthType: string(connResource.AuthType), + Credentials: connResource.Credentials, + Metadata: connResource.Metadata, + ExpiryTime: connResource.ExpiryTime, + IsSharedToAll: connResource.IsSharedToAll, + SharedUserList: connResource.SharedUserList, + PeRequirement: connResource.PeRequirement, + PeStatus: connResource.PeStatus, + UseWorkspaceManagedIdentity: connResource.UseWorkspaceManagedIdentity, + Error: connResource.Error, + } + + // Surface credentials.type to top-level authType when not explicitly set. + // The API expects authType at the connection level, not nested in credentials. + if conn.AuthType == "" { + if credType, ok := conn.Credentials["type"].(string); ok && credType != "" { + conn.AuthType = credType + delete(conn.Credentials, "type") + } + } + + connections = append(connections, conn) + } + + return connections, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go index 166f02b7601..f7df342ca52 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code.go @@ -848,13 +848,13 @@ type protocolInfo struct { // knownProtocols lists the protocols offered during init, in display order. var knownProtocols = []protocolInfo{ - {Name: "responses", Version: "v1"}, - {Name: "invocations", Version: "v0.0.1"}, + {Name: "responses", Version: "1.0.0"}, + {Name: "invocations", Version: "1.0.0"}, } // promptProtocols asks the user which protocols their agent supports. // When flagProtocols is non-empty the prompt is skipped and those values are used directly. -// When noPrompt is true and no flag values are provided, defaults to [responses/v1]. +// When noPrompt is true and no flag values are provided, defaults to [responses/1.0.0]. func promptProtocols( ctx context.Context, promptClient azdext.PromptServiceClient, @@ -897,7 +897,7 @@ func promptProtocols( // Non-interactive mode: default to responses. if noPrompt { return []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, nil } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go index 968eabdefa7..e32bda5bddc 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_from_code_test.go @@ -412,7 +412,7 @@ func TestWriteDefinitionToSrcDir(t *testing.T) { Kind: agent_yaml.AgentKindHosted, }, Protocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, @@ -601,22 +601,22 @@ func TestPromptProtocols_FlagValues(t *testing.T) { name: "responses only", flagProtocols: []string{"responses"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, }, { name: "invocations only", flagProtocols: []string{"invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { name: "both protocols", flagProtocols: []string{"responses", "invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { @@ -629,8 +629,8 @@ func TestPromptProtocols_FlagValues(t *testing.T) { name: "duplicates are removed", flagProtocols: []string{"responses", "responses", "invocations"}, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, } @@ -680,8 +680,8 @@ func TestPromptProtocols_NoPromptDefault(t *testing.T) { if got[0].Protocol != "responses" { t.Errorf("protocol = %q, want %q", got[0].Protocol, "responses") } - if got[0].Version != "v1" { - t.Errorf("version = %q, want %q", got[0].Version, "v1") + if got[0].Version != "1.0.0" { + t.Errorf("version = %q, want %q", got[0].Version, "1.0.0") } } @@ -736,8 +736,8 @@ func TestPromptProtocols_Interactive(t *testing.T) { }, nil }, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, - {Protocol: "invocations", Version: "v0.0.1"}, + {Protocol: "responses", Version: "1.0.0"}, + {Protocol: "invocations", Version: "1.0.0"}, }, }, { @@ -751,7 +751,7 @@ func TestPromptProtocols_Interactive(t *testing.T) { }, nil }, wantProtocols: []agent_yaml.ProtocolVersionRecord{ - {Protocol: "responses", Version: "v1"}, + {Protocol: "responses", Version: "1.0.0"}, }, }, { diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go index 6fb935235a3..6f78c67070d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/init_test.go @@ -9,6 +9,8 @@ import ( "path/filepath" "testing" + "azureaiagent/internal/pkg/agents/agent_yaml" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" @@ -394,3 +396,547 @@ func TestParseGitHubUrlNaive(t *testing.T) { }) } } + +func TestExtractToolboxAndConnectionConfigs_TypedTools(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "platform-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Description: "Platform tools", + Tools: []any{ + map[string]any{ + // Built-in tool — no connection + "type": "bing_grounding", + }, + map[string]any{ + // External tool with name — connection name from Name field + "type": "mcp", + "name": "github-copilot", + "target": "https://api.githubcopilot.com/mcp", + "authType": "OAuth2", + "credentials": map[string]any{ + "clientId": "my-client-id", + "clientSecret": "my-secret", + }, + }, + }, + }, + }, + } + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + // Only the external tool creates a connection (not bing_grounding) + if len(connections) != 1 { + t.Fatalf("Expected 1 connection, got %d", len(connections)) + } + conn := connections[0] + if conn.Name != "github-copilot" { + t.Errorf("Expected connection name 'github-copilot', got '%s'", conn.Name) + } + if conn.Category != "RemoteTool" { + t.Errorf("Expected category 'RemoteTool', got '%s'", conn.Category) + } + if conn.Target != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected target, got '%s'", conn.Target) + } + if conn.AuthType != "OAuth2" { + t.Errorf("Expected authType 'OAuth2', got '%s'", conn.AuthType) + } + + // Credentials should be ${VAR} references, not raw values + if conn.Credentials["clientId"] != "${TOOL_GITHUB_COPILOT_CLIENTID}" { + t.Errorf("Expected env var ref for clientId, got '%v'", conn.Credentials["clientId"]) + } + if conn.Credentials["clientSecret"] != "${TOOL_GITHUB_COPILOT_CLIENTSECRET}" { + t.Errorf("Expected env var ref for clientSecret, got '%v'", conn.Credentials["clientSecret"]) + } + + // Raw values should be in the credEnvVars map + if credEnvVars["TOOL_GITHUB_COPILOT_CLIENTID"] != "my-client-id" { + t.Errorf("Expected env var value 'my-client-id', got '%s'", + credEnvVars["TOOL_GITHUB_COPILOT_CLIENTID"]) + } + if credEnvVars["TOOL_GITHUB_COPILOT_CLIENTSECRET"] != "my-secret" { + t.Errorf("Expected env var value 'my-secret', got '%s'", + credEnvVars["TOOL_GITHUB_COPILOT_CLIENTSECRET"]) + } + + // Verify toolbox has both tools + if len(toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(toolboxes)) + } + tb := toolboxes[0] + if tb.Name != "platform-tools" { + t.Errorf("Expected toolbox name 'platform-tools', got '%s'", tb.Name) + } + if tb.Description != "Platform tools" { + t.Errorf("Expected description 'Platform tools', got '%s'", tb.Description) + } + if len(tb.Tools) != 2 { + t.Fatalf("Expected 2 tools, got %d", len(tb.Tools)) + } + + // First tool: built-in (no project_connection_id) + if tb.Tools[0]["type"] != "bing_grounding" { + t.Errorf("Expected tool[0] type 'bing_grounding', got '%v'", tb.Tools[0]["type"]) + } + if _, hasConn := tb.Tools[0]["project_connection_id"]; hasConn { + t.Errorf("Built-in tool should not have project_connection_id") + } + + // Second tool: minimal (type + project_connection_id only) + if tb.Tools[1]["project_connection_id"] != "github-copilot" { + t.Errorf("Expected project_connection_id 'github-copilot', got '%v'", + tb.Tools[1]["project_connection_id"]) + } + if tb.Tools[1]["type"] != "mcp" { + t.Errorf("Expected tool type 'mcp', got '%v'", tb.Tools[1]["type"]) + } + // No server_url or server_label in init output — deploy enriches from connections + if _, has := tb.Tools[1]["server_url"]; has { + t.Errorf("Toolbox tool should not have server_url (deploy enriches it)") + } + if _, has := tb.Tools[1]["server_label"]; has { + t.Errorf("Toolbox tool should not have server_label (deploy enriches it)") + } +} + +func TestExtractToolboxAndConnectionConfigs_RawToolsFallback(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "raw-toolbox", + Kind: agent_yaml.ResourceKindToolbox, + }, + Description: "Raw tools", + Tools: []any{ + map[string]any{ + "type": "mcp", + "name": "existing", + "project_connection_id": "existing-conn", + }, + }, + }, + }, + } + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + // No connections or env vars extracted from raw tools + if len(connections) != 0 { + t.Errorf("Expected 0 connections, got %d", len(connections)) + } + if len(credEnvVars) != 0 { + t.Errorf("Expected 0 env vars, got %d", len(credEnvVars)) + } + + if len(toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(toolboxes)) + } + if toolboxes[0].Tools[0]["project_connection_id"] != "existing-conn" { + t.Errorf("Expected 'existing-conn', got '%v'", toolboxes[0].Tools[0]["project_connection_id"]) + } +} + +func TestExtractToolboxAndConnectionConfigs_NilManifest(t *testing.T) { + t.Parallel() + + toolboxes, connections, credEnvVars, err := extractToolboxAndConnectionConfigs(nil) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if toolboxes != nil { + t.Errorf("Expected nil toolboxes, got %v", toolboxes) + } + if connections != nil { + t.Errorf("Expected nil connections, got %v", connections) + } + if credEnvVars != nil { + t.Errorf("Expected nil env vars, got %v", credEnvVars) + } +} + +func TestExtractToolboxAndConnectionConfigs_CustomKeysCredentials(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "my-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{ + "type": "mcp", + "name": "custom-api", + "target": "https://example.com/mcp", + "authType": "CustomKeys", + "credentials": map[string]any{"key": "my-api-key"}, + }, + map[string]any{ + "type": "mcp", + "name": "oauth-tool", + "target": "https://example.com/oauth", + "authType": "OAuth2", + "credentials": map[string]any{"clientId": "id", "clientSecret": "secret"}, + }, + }, + }, + }, + } + + _, connections, _, err := extractToolboxAndConnectionConfigs(manifest) + if err != nil { + t.Fatalf("extractToolboxAndConnectionConfigs failed: %v", err) + } + + if len(connections) != 2 { + t.Fatalf("Expected 2 connections, got %d", len(connections)) + } + + // CustomKeys: credentials must be nested under "keys" + customConn := connections[0] + keysRaw, ok := customConn.Credentials["keys"] + if !ok { + t.Fatal("CustomKeys connection missing 'keys' wrapper in credentials") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected 'keys' to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${TOOL_CUSTOM_API_KEY}" { + t.Errorf("Expected env var ref for key, got '%v'", keys["key"]) + } + + // OAuth2: credentials should be flat (no "keys" wrapper) + oauthConn := connections[1] + if _, hasKeys := oauthConn.Credentials["keys"]; hasKeys { + t.Error("OAuth2 connection should not have 'keys' wrapper") + } + if oauthConn.Credentials["clientId"] != "${TOOL_OAUTH_TOOL_CLIENTID}" { + t.Errorf("Expected flat clientId ref, got '%v'", oauthConn.Credentials["clientId"]) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_AddsEnvVars(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "1.0.0"}, + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "agent-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{"type": "bing_grounding"}, + }, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + if len(envVars) != 2 { + t.Fatalf("Expected 2 env vars, got %d", len(envVars)) + } + + // Original env var is preserved + if envVars[0].Name != "AZURE_OPENAI_ENDPOINT" { + t.Errorf("Expected first env var to be AZURE_OPENAI_ENDPOINT, got %s", envVars[0].Name) + } + + // Toolbox env var is injected + if envVars[1].Name != "TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected injected env var name, got %s", envVars[1].Name) + } + if envVars[1].Value != "${TOOLBOX_AGENT_TOOLS_MCP_ENDPOINT}" { + t.Errorf("Expected env var reference value, got %s", envVars[1].Value) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_SkipsExisting(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "1.0.0"}, + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "TOOLBOX_MY_TOOLS_MCP_ENDPOINT", Value: "custom-value"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{ + Name: "my-tools", + Kind: agent_yaml.ResourceKindToolbox, + }, + Tools: []any{ + map[string]any{"type": "bing_grounding"}, + }, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + // Should not add a duplicate — user's value is preserved + if len(envVars) != 1 { + t.Fatalf("Expected 1 env var (no duplicate), got %d", len(envVars)) + } + if envVars[0].Value != "custom-value" { + t.Errorf("Expected user value preserved, got %s", envVars[0].Value) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_MultipleToolboxes(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + Protocols: []agent_yaml.ProtocolVersionRecord{ + {Protocol: "responses", Version: "1.0.0"}, + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "search-tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"type": "bing_grounding"}}, + }, + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "github-tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"type": "mcp", "target": "https://example.com"}}, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + envVars := *containerAgent.EnvironmentVariables + + if len(envVars) != 2 { + t.Fatalf("Expected 2 env vars, got %d", len(envVars)) + } + if envVars[0].Name != "TOOLBOX_SEARCH_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected first toolbox env var, got %s", envVars[0].Name) + } + if envVars[1].Name != "TOOLBOX_GITHUB_TOOLS_MCP_ENDPOINT" { + t.Errorf("Expected second toolbox env var, got %s", envVars[1].Name) + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopForNilManifest(t *testing.T) { + t.Parallel() + + // Should not panic + injectToolboxEnvVarsIntoDefinition(nil) +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopForPromptAgent(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.PromptAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindPrompt, + Name: "prompt-agent", + }, + }, + Resources: []any{ + agent_yaml.ToolboxResource{ + Resource: agent_yaml.Resource{Name: "tools", Kind: agent_yaml.ResourceKindToolbox}, + Tools: []any{map[string]any{"type": "bing_grounding"}}, + }, + }, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + // Template should be unchanged (still a PromptAgent, no EnvironmentVariables field) + if _, ok := manifest.Template.(agent_yaml.PromptAgent); !ok { + t.Error("Expected template to remain a PromptAgent") + } +} + +func TestInjectToolboxEnvVarsIntoDefinition_NoopWithoutToolboxes(t *testing.T) { + t.Parallel() + + manifest := &agent_yaml.AgentManifest{ + Template: agent_yaml.ContainerAgent{ + AgentDefinition: agent_yaml.AgentDefinition{ + Kind: agent_yaml.AgentKindHosted, + Name: "my-agent", + }, + EnvironmentVariables: &[]agent_yaml.EnvironmentVariable{ + {Name: "AZURE_OPENAI_ENDPOINT", Value: "${AZURE_OPENAI_ENDPOINT}"}, + }, + }, + Resources: []any{}, + } + + injectToolboxEnvVarsIntoDefinition(manifest) + + containerAgent := manifest.Template.(agent_yaml.ContainerAgent) + if len(*containerAgent.EnvironmentVariables) != 1 { + t.Errorf("Expected env vars unchanged, got %d", len(*containerAgent.EnvironmentVariables)) + } +} + +func TestToolboxMCPEndpointEnvKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + {"simple", "my-tools", "TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"spaces", "my tools", "TOOLBOX_MY_TOOLS_MCP_ENDPOINT"}, + {"mixed", "agent-tools v2", "TOOLBOX_AGENT_TOOLS_V2_MCP_ENDPOINT"}, + {"already upper", "TOOLS", "TOOLBOX_TOOLS_MCP_ENDPOINT"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := toolboxMCPEndpointEnvKey(tt.input) + if got != tt.expected { + t.Errorf("toolboxMCPEndpointEnvKey(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestExtractConnectionConfigs_SurfacesCredentialsType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connResource agent_yaml.ConnectionResource + wantAuthType string + wantCredHasType bool + wantCredKeyCount int + }{ + { + name: "surfaces credentials.type to authType when authType is empty", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + Credentials: map[string]any{ + "type": "CustomKeys", + "key": "secret-value", + }, + }, + wantAuthType: "CustomKeys", + wantCredHasType: false, + wantCredKeyCount: 1, + }, + { + name: "preserves explicit authType even if credentials.type differs", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + AuthType: agent_yaml.AuthTypeAAD, + Credentials: map[string]any{ + "type": "CustomKeys", + "key": "val", + }, + }, + wantAuthType: string(agent_yaml.AuthTypeAAD), + wantCredHasType: true, + wantCredKeyCount: 2, + }, + { + name: "no credentials.type and no authType stays empty", + connResource: agent_yaml.ConnectionResource{ + Resource: agent_yaml.Resource{ + Name: "my-conn", + Kind: agent_yaml.ResourceKindConnection, + }, + Target: "https://example.com", + Credentials: map[string]any{"key": "val"}, + }, + wantAuthType: "", + wantCredHasType: false, + wantCredKeyCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manifest := &agent_yaml.AgentManifest{ + Resources: []any{tt.connResource}, + } + conns, err := extractConnectionConfigs(manifest) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(conns) != 1 { + t.Fatalf("expected 1 connection, got %d", len(conns)) + } + conn := conns[0] + if conn.AuthType != tt.wantAuthType { + t.Errorf("AuthType = %q, want %q", conn.AuthType, tt.wantAuthType) + } + _, hasType := conn.Credentials["type"] + if hasType != tt.wantCredHasType { + t.Errorf("credentials has 'type' = %v, want %v", + hasType, tt.wantCredHasType) + } + if len(conn.Credentials) != tt.wantCredKeyCount { + t.Errorf("credentials key count = %d, want %d", + len(conn.Credentials), tt.wantCredKeyCount) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go index a8b08c9e7e0..ab7a2cd8ba8 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen.go @@ -12,11 +12,15 @@ import ( "strconv" "strings" + "azureaiagent/internal/exterrors" "azureaiagent/internal/pkg/agents/agent_yaml" + "azureaiagent/internal/pkg/azure" "azureaiagent/internal/project" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" + "github.com/drone/envsubst" "github.com/spf13/cobra" "google.golang.org/protobuf/types/known/structpb" ) @@ -47,6 +51,9 @@ func newListenCommand() *cobra.Command { WithProjectEventHandler("preprovision", func(ctx context.Context, args *azdext.ProjectEventArgs) error { return preprovisionHandler(ctx, azdClient, args) }). + WithProjectEventHandler("postprovision", func(ctx context.Context, args *azdext.ProjectEventArgs) error { + return postprovisionHandler(ctx, azdClient, args) + }). WithProjectEventHandler("predeploy", func(ctx context.Context, args *azdext.ProjectEventArgs) error { return predeployHandler(ctx, azdClient, args) }). @@ -81,6 +88,27 @@ func preprovisionHandler(ctx context.Context, azdClient *azdext.AzdClient, args return nil } +func postprovisionHandler( + ctx context.Context, + azdClient *azdext.AzdClient, + args *azdext.ProjectEventArgs, +) error { + for _, svc := range args.Project.Services { + if svc.Host != AiAgentHost { + continue + } + + if err := provisionToolboxes(ctx, azdClient, svc); err != nil { + return fmt.Errorf( + "failed to provision toolboxes for service %q: %w", + svc.Name, err, + ) + } + } + + return nil +} + func predeployHandler(ctx context.Context, azdClient *azdext.AzdClient, args *azdext.ProjectEventArgs) error { for _, svc := range args.Project.Services { switch svc.Host { @@ -147,6 +175,24 @@ func envUpdate(ctx context.Context, azdClient *azdext.AzdClient, azdProject *azd } } + if len(foundryAgentConfig.Connections) > 0 { + if err := connectionsEnvUpdate( + ctx, foundryAgentConfig.Connections, + azdClient, currentEnvResponse.Environment.Name, + ); err != nil { + return err + } + } + + if len(foundryAgentConfig.ToolConnections) > 0 { + if err := toolConnectionsEnvUpdate( + ctx, foundryAgentConfig.ToolConnections, + azdClient, currentEnvResponse.Environment.Name, + ); err != nil { + return err + } + } + return nil } @@ -227,6 +273,112 @@ func resourcesEnvUpdate(ctx context.Context, resources []project.Resource, azdCl return setEnvVar(ctx, azdClient, envName, "AI_PROJECT_DEPENDENT_RESOURCES", escapedJsonString) } +func connectionsEnvUpdate( + ctx context.Context, + connections []project.Connection, + azdClient *azdext.AzdClient, + envName string, +) error { + // Strip credentials from the connections env var — Bicep's ConnectionConfig + // type doesn't include credentials (they're a separate @secure param). + // Including them causes "unable to deserialize request body" errors. + stripped := make([]project.Connection, len(connections)) + copy(stripped, connections) + for i := range stripped { + stripped[i].Credentials = nil + } + + if err := marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTIONS", stripped); err != nil { + return err + } + + return connectionCredentialsEnvUpdate(ctx, connections, azdClient, envName) +} + +// connectionCredentialsEnvUpdate builds a dictionary of connection name → credentials +// and serializes it to AI_PROJECT_CONNECTION_CREDENTIALS. +func connectionCredentialsEnvUpdate( + ctx context.Context, + connections []project.Connection, + azdClient *azdext.AzdClient, + envName string, +) error { + credMap := buildConnectionCredentials(connections) + if len(credMap) == 0 { + return nil + } + + return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_CONNECTION_CREDENTIALS", credMap) +} + +// buildConnectionCredentials returns a map of connection name → credentials object +// for all connections that have non-empty credentials. +func buildConnectionCredentials(connections []project.Connection) map[string]map[string]any { + result := map[string]map[string]any{} + for _, conn := range connections { + if len(conn.Credentials) > 0 { + result[conn.Name] = conn.Credentials + } + } + + return result +} + +// toolConnectionsEnvUpdate serializes tool connections to AI_PROJECT_TOOL_CONNECTIONS env var. +func toolConnectionsEnvUpdate( + ctx context.Context, + connections []project.ToolConnection, + azdClient *azdext.AzdClient, + envName string, +) error { + // Normalize credentials before serializing: CustomKeys authType requires + // credentials nested under "keys" for the ARM API. + normalized := make([]project.ToolConnection, len(connections)) + copy(normalized, connections) + for i := range normalized { + normalized[i].Credentials = normalizeCredentials(normalized[i].AuthType, normalized[i].Credentials) + } + + return marshalAndSetEnvVar(ctx, azdClient, envName, "AI_PROJECT_TOOL_CONNECTIONS", normalized) +} + +// marshalAndSetEnvVar serializes a value to JSON, escapes it for safe storage +// in an azd environment variable, and persists it. +func marshalAndSetEnvVar( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, + key string, + value any, +) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("failed to marshal %s to JSON: %w", key, err) + } + + jsonString := string(data) + escaped := strings.ReplaceAll(jsonString, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, "\"", "\\\"") + + return setEnvVar(ctx, azdClient, envName, key, escaped) +} + +// normalizeCredentials ensures credentials match the expected ARM format. +// CustomKeys requires credentials nested under "keys": { "keys": { "key": "val" } }. +// If already wrapped, returns as-is. Other auth types are returned unchanged. +func normalizeCredentials(authType string, creds map[string]any) map[string]any { + if authType != "CustomKeys" || len(creds) == 0 { + return creds + } + + // Already correctly wrapped + if _, hasKeys := creds["keys"]; hasKeys && len(creds) == 1 { + return creds + } + + return map[string]any{"keys": creds} +} + func setEnvVar(ctx context.Context, azdClient *azdext.AzdClient, envName string, key string, value string) error { _, err := azdClient.Environment().SetValue(ctx, &azdext.SetEnvRequest{ EnvName: envName, @@ -316,3 +468,325 @@ func populateContainerSettings(ctx context.Context, azdClient *azdext.AzdClient, return nil } + +// provisionToolboxes creates or updates Foundry Toolsets for each toolbox +// in the service config. Called during post-provision after the project +// endpoint has been created by Bicep. +func provisionToolboxes( + ctx context.Context, + azdClient *azdext.AzdClient, + svc *azdext.ServiceConfig, +) error { + var config *project.ServiceTargetAgentConfig + if err := project.UnmarshalStruct(svc.Config, &config); err != nil { + return fmt.Errorf("failed to parse service config: %w", err) + } + + if config == nil || len(config.Toolboxes) == 0 { + return nil + } + + currentEnv, err := azdClient.Environment().GetCurrent( + ctx, &azdext.EmptyRequest{}, + ) + if err != nil { + return fmt.Errorf("failed to get current environment: %w", err) + } + + envValue, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: currentEnv.Environment.Name, + Key: "AZURE_AI_PROJECT_ENDPOINT", + }) + if err != nil || envValue.Value == "" { + return exterrors.Dependency( + exterrors.CodeMissingAiProjectEndpoint, + "AZURE_AI_PROJECT_ENDPOINT is required for toolbox provisioning", + "run 'azd provision' to create the AI project first", + ) + } + projectEndpoint := envValue.Value + + envValue, err = azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: currentEnv.Environment.Name, + Key: "AZURE_TENANT_ID", + }) + if err != nil || envValue.Value == "" { + return exterrors.Dependency( + exterrors.CodeMissingAzureTenantId, + "AZURE_TENANT_ID is required for toolbox provisioning", + "run 'azd auth login' to authenticate", + ) + } + tenantId := envValue.Value + + cred, err := azidentity.NewAzureDeveloperCLICredential( + &azidentity.AzureDeveloperCLICredentialOptions{ + TenantID: tenantId, + AdditionallyAllowedTenants: []string{"*"}, + }, + ) + if err != nil { + return exterrors.Auth( + exterrors.CodeCredentialCreationFailed, + fmt.Sprintf("failed to create credential: %s", err), + "run 'azd auth login' to authenticate", + ) + } + + toolboxClient := azure.NewFoundryToolboxClient( + projectEndpoint, cred, + ) + + // Build azd env lookup for resolving ${VAR} references in tool entries + azdEnv, err := getAllEnvVars(ctx, azdClient, currentEnv.Environment.Name) + if err != nil { + return fmt.Errorf("failed to load environment variables: %w", err) + } + + // Build connection ID lookup from bicep outputs (name → ARM resource ID) + connIDMap := parseConnectionIDs(azdEnv["AI_PROJECT_CONNECTION_IDS_JSON"]) + + // Build connection lookup for enriching tool entries with server_url/server_label + connByName := map[string]project.ToolConnection{} + for _, c := range config.ToolConnections { + connByName[c.Name] = c + } + + for _, toolbox := range config.Toolboxes { + fmt.Fprintf( + os.Stderr, "Provisioning toolbox: %s\n", toolbox.Name, + ) + + // Resolve ${VAR} references in tool map values before sending to API + resolveToolboxEnvVars(&toolbox, azdEnv) + + // Fill in server_url/server_label from connection data + enrichToolboxFromConnections(&toolbox, connByName) + + // Replace project_connection_id friendly names with ARM resource IDs + resolveToolboxConnectionIDs(&toolbox, connIDMap) + + version, err := createToolboxVersion( + ctx, toolboxClient, toolbox, + ) + if err != nil { + return err + } + + if err := registerToolboxEnvVars( + ctx, azdClient, + currentEnv.Environment.Name, + projectEndpoint, toolbox.Name, version, + ); err != nil { + return err + } + + fmt.Fprintf( + os.Stderr, "Toolbox '%s' provisioned\n", toolbox.Name, + ) + } + + return nil +} + +// createToolboxVersion creates a new version of a toolbox. +// If the toolbox does not exist, it will be created automatically. +// Returns the version identifier of the newly created version. +func createToolboxVersion( + ctx context.Context, + client *azure.FoundryToolboxClient, + toolbox project.Toolbox, +) (string, error) { + req := &azure.CreateToolboxVersionRequest{ + Description: toolbox.Description, + Tools: toolbox.Tools, + } + + result, err := client.CreateToolboxVersion(ctx, toolbox.Name, req) + if err != nil { + return "", exterrors.Internal( + exterrors.CodeCreateToolboxVersionFailed, + fmt.Sprintf("failed to create toolbox version '%s': %s", toolbox.Name, err), + ) + } + + return result.Version, nil +} + +// registerToolboxEnvVars sets TOOLBOX_{NAME}_MCP_ENDPOINT with the versioned URL. +func registerToolboxEnvVars( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, + projectEndpoint string, + toolboxName string, + toolboxVersion string, +) error { + envKey := toolboxMCPEndpointEnvKey(toolboxName) + + endpoint := strings.TrimRight(projectEndpoint, "/") + mcpEndpoint := fmt.Sprintf( + "%s/toolboxes/%s/versions/%s/mcp?api-version=v1", + endpoint, toolboxName, toolboxVersion, + ) + + return setEnvVar( + ctx, azdClient, envName, envKey, mcpEndpoint, + ) +} + +// toolboxMCPEndpointEnvKey builds the TOOLBOX_{NAME}_MCP_ENDPOINT env var key. +func toolboxMCPEndpointEnvKey(toolboxName string) string { + key := strings.ReplaceAll(toolboxName, " ", "_") + key = strings.ReplaceAll(key, "-", "_") + return fmt.Sprintf("TOOLBOX_%s_MCP_ENDPOINT", strings.ToUpper(key)) +} + +// resolveToolboxEnvVars resolves ${VAR} references in toolbox name, description, +// and all tool map values using the provided azd environment variables. +func resolveToolboxEnvVars(toolbox *project.Toolbox, azdEnv map[string]string) { + toolbox.Name = resolveEnvValue(toolbox.Name, azdEnv) + toolbox.Description = resolveEnvValue(toolbox.Description, azdEnv) + for i, tool := range toolbox.Tools { + toolbox.Tools[i] = resolveMapValues(tool, azdEnv) + } +} + +// enrichToolboxFromConnections fills in server_url and server_label on toolbox +// tools that reference a connection via project_connection_id. This keeps the +// azure.yaml toolbox entries minimal while sending complete data to the API. +func enrichToolboxFromConnections( + toolbox *project.Toolbox, + connByName map[string]project.ToolConnection, +) { + for i, tool := range toolbox.Tools { + connID, _ := tool["project_connection_id"].(string) + if connID == "" { + continue + } + conn, ok := connByName[connID] + if !ok { + continue + } + if _, has := tool["server_url"]; !has && conn.Target != "" { + toolbox.Tools[i]["server_url"] = conn.Target + } + if _, has := tool["server_label"]; !has { + toolbox.Tools[i]["server_label"] = conn.Name + } + } +} + +// parseConnectionIDs parses the AI_PROJECT_CONNECTION_IDS_JSON env var +// (a JSON array of {name, id} objects) into a map of name → ARM resource ID. +func parseConnectionIDs(jsonStr string) map[string]string { + result := map[string]string{} + if jsonStr == "" { + return result + } + + var entries []struct { + Name string `json:"name"` + ID string `json:"id"` + } + if err := json.Unmarshal([]byte(jsonStr), &entries); err != nil { + fmt.Fprintf(os.Stderr, + "Warning: failed to parse AI_PROJECT_CONNECTION_IDS_JSON: %s\n", err) + return result + } + + for _, e := range entries { + if e.Name != "" && e.ID != "" { + result[e.Name] = e.ID + } + } + return result +} + +// resolveToolboxConnectionIDs replaces project_connection_id friendly names +// with their actual ARM resource IDs from bicep provisioning outputs. +func resolveToolboxConnectionIDs( + toolbox *project.Toolbox, + connIDs map[string]string, +) { + if len(connIDs) == 0 { + return + } + for i, tool := range toolbox.Tools { + connName, _ := tool["project_connection_id"].(string) + if connName == "" { + continue + } + connName = resolveTemplateRef(connName) + if armID, ok := connIDs[connName]; ok { + toolbox.Tools[i]["project_connection_id"] = armID + } + } +} + +// resolveTemplateRef strips {{ }} template wrapping and trims whitespace. +// "{{ my_conn }}" → "my_conn", "my_conn" → "my_conn" (unchanged). +func resolveTemplateRef(s string) string { + if strings.HasPrefix(s, "{{") && strings.HasSuffix(s, "}}") { + return strings.TrimSpace(s[2 : len(s)-2]) + } + return s +} + +// resolveEnvValue resolves ${VAR} references in a string using envsubst. +func resolveEnvValue(value string, azdEnv map[string]string) string { + resolved, err := envsubst.Eval(value, func(varName string) string { + return azdEnv[varName] + }) + if err != nil { + return value + } + return resolved +} + +// resolveMapValues recursively resolves ${VAR} references in all string values of a map. +func resolveMapValues(m map[string]any, azdEnv map[string]string) map[string]any { + resolved := make(map[string]any, len(m)) + for k, v := range m { + resolved[k] = resolveAnyValue(v, azdEnv) + } + return resolved +} + +// resolveAnyValue resolves ${VAR} references in a value of any type. +func resolveAnyValue(v any, azdEnv map[string]string) any { + switch val := v.(type) { + case string: + return resolveEnvValue(val, azdEnv) + case map[string]any: + return resolveMapValues(val, azdEnv) + case []any: + resolved := make([]any, len(val)) + for i, item := range val { + resolved[i] = resolveAnyValue(item, azdEnv) + } + return resolved + default: + return v + } +} + +// getAllEnvVars loads all environment variables from the azd environment. +func getAllEnvVars( + ctx context.Context, + azdClient *azdext.AzdClient, + envName string, +) (map[string]string, error) { + resp, err := azdClient.Environment().GetValues(ctx, &azdext.GetEnvironmentRequest{ + Name: envName, + }) + if err != nil { + return nil, err + } + + envMap := make(map[string]string, len(resp.KeyValues)) + for _, kv := range resp.KeyValues { + envMap[kv.Key] = kv.Value + } + return envMap, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go new file mode 100644 index 00000000000..9131ce2b5bd --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/listen_test.go @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "azureaiagent/internal/project" +) + +func TestNormalizeCredentials_CustomKeys_FlatToNested(t *testing.T) { + t.Parallel() + + // Old-format flat credentials should be wrapped under "keys" + creds := map[string]any{"key": "${TOOL_CONTEXT7_KEY}"} + result := normalizeCredentials("CustomKeys", creds) + + keysRaw, ok := result["keys"] + if !ok { + t.Fatal("Expected 'keys' wrapper in normalized credentials") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${TOOL_CONTEXT7_KEY}" { + t.Errorf("Expected key value preserved, got %v", keys["key"]) + } +} + +func TestNormalizeCredentials_CustomKeys_AlreadyNested(t *testing.T) { + t.Parallel() + + // Already-correct nested credentials should be returned as-is + creds := map[string]any{ + "keys": map[string]any{"key": "${TOOL_CONTEXT7_KEY}"}, + } + result := normalizeCredentials("CustomKeys", creds) + + keysRaw, ok := result["keys"] + if !ok { + t.Fatal("Expected 'keys' wrapper preserved") + } + keys, ok := keysRaw.(map[string]any) + if !ok { + t.Fatalf("Expected keys to be map[string]any, got %T", keysRaw) + } + if keys["key"] != "${TOOL_CONTEXT7_KEY}" { + t.Errorf("Expected key value preserved, got %v", keys["key"]) + } + if len(result) != 1 { + t.Errorf("Expected only 'keys' in result, got %d entries", len(result)) + } +} + +func TestNormalizeCredentials_OAuth2_Unchanged(t *testing.T) { + t.Parallel() + + // Non-CustomKeys auth types should be returned unchanged + creds := map[string]any{ + "clientId": "${VAR_ID}", + "clientSecret": "${VAR_SECRET}", + } + result := normalizeCredentials("OAuth2", creds) + + if _, hasKeys := result["keys"]; hasKeys { + t.Error("OAuth2 credentials should not be wrapped in 'keys'") + } + if result["clientId"] != "${VAR_ID}" { + t.Errorf("Expected clientId preserved, got %v", result["clientId"]) + } +} + +func TestNormalizeCredentials_EmptyCredentials(t *testing.T) { + t.Parallel() + + result := normalizeCredentials("CustomKeys", nil) + if result != nil { + t.Errorf("Expected nil for nil input, got %v", result) + } + + result = normalizeCredentials("CustomKeys", map[string]any{}) + if len(result) != 0 { + t.Errorf("Expected empty map for empty input, got %v", result) + } +} + +func TestParseConnectionIDs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + json string + expected map[string]string + }{ + { + name: "valid array", + json: `[{"name":"my-conn","id":"/subscriptions/123/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/ai/projects/proj/connections/my-conn"}]`, + expected: map[string]string{"my-conn": "/subscriptions/123/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/ai/projects/proj/connections/my-conn"}, + }, + { + name: "empty string", + json: "", + expected: map[string]string{}, + }, + { + name: "empty array", + json: "[]", + expected: map[string]string{}, + }, + { + name: "invalid JSON", + json: "not-json", + expected: map[string]string{}, + }, + { + name: "multiple connections", + json: `[{"name":"conn-a","id":"id-a"},{"name":"conn-b","id":"id-b"}]`, + expected: map[string]string{ + "conn-a": "id-a", + "conn-b": "id-b", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseConnectionIDs(tt.json) + if len(result) != len(tt.expected) { + t.Fatalf("got %d entries, want %d", len(result), len(tt.expected)) + } + for k, v := range tt.expected { + if result[k] != v { + t.Errorf("key %q: got %q, want %q", k, result[k], v) + } + } + }) + } +} + +func TestResolveToolboxConnectionIDs(t *testing.T) { + t.Parallel() + + connIDs := map[string]string{ + "github_mcp_connection": "/subscriptions/123/connections/github_mcp_connection", + } + + toolbox := project.Toolbox{ + Name: "test", + Tools: []map[string]any{ + {"type": "web_search"}, + {"type": "mcp", "project_connection_id": "{{ github_mcp_connection }}"}, + {"type": "mcp", "project_connection_id": "unknown_conn"}, + {"type": "mcp", "project_connection_id": "github_mcp_connection"}, + }, + } + + resolveToolboxConnectionIDs(&toolbox, connIDs) + + // Tool without project_connection_id: unchanged + if _, has := toolbox.Tools[0]["project_connection_id"]; has { + t.Error("tool 0 should not have project_connection_id") + } + + // Template ref {{ name }}: resolved to ARM ID + if toolbox.Tools[1]["project_connection_id"] != "/subscriptions/123/connections/github_mcp_connection" { + t.Errorf("tool 1 project_connection_id = %v, want ARM ID", + toolbox.Tools[1]["project_connection_id"]) + } + + // Unknown connection: left as-is + if toolbox.Tools[2]["project_connection_id"] != "unknown_conn" { + t.Errorf("tool 2 project_connection_id = %v, want 'unknown_conn'", + toolbox.Tools[2]["project_connection_id"]) + } + + // Bare name (no braces): also resolved + if toolbox.Tools[3]["project_connection_id"] != "/subscriptions/123/connections/github_mcp_connection" { + t.Errorf("tool 3 project_connection_id = %v, want ARM ID", + toolbox.Tools[3]["project_connection_id"]) + } +} + +func TestResolveTemplateRef(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + want string + }{ + {"{{ my_conn }}", "my_conn"}, + {"{{my_conn}}", "my_conn"}, + {"{{ spaced }}", "spaced"}, + {"my_conn", "my_conn"}, + {"", ""}, + {"{not_template}", "{not_template}"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + if got := resolveTemplateRef(tt.input); got != tt.want { + t.Errorf("resolveTemplateRef(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestBuildConnectionCredentials(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connections []project.Connection + wantKeys []string + wantEmpty bool + }{ + { + name: "empty connections", + wantEmpty: true, + }, + { + name: "connections with credentials", + connections: []project.Connection{ + { + Name: "my-openai", + Credentials: map[string]any{"key": "${OPENAI_API_KEY}"}, + }, + { + Name: "github-mcp", + Credentials: map[string]any{"pat": "${GITHUB_PAT}"}, + }, + }, + wantKeys: []string{"my-openai", "github-mcp"}, + }, + { + name: "skips connections without credentials", + connections: []project.Connection{ + { + Name: "no-creds", + Credentials: nil, + }, + { + Name: "has-creds", + Credentials: map[string]any{"secret": "val"}, + }, + }, + wantKeys: []string{"has-creds"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := buildConnectionCredentials(tt.connections) + + if tt.wantEmpty { + if len(result) != 0 { + t.Fatalf("expected empty map, got %v", result) + } + return + } + + if len(result) != len(tt.wantKeys) { + t.Fatalf("expected %d entries, got %d: %v", + len(tt.wantKeys), len(result), result) + } + + for _, key := range tt.wantKeys { + if _, ok := result[key]; !ok { + t.Errorf("expected key %q in result", key) + } + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go index fe700c880e6..bfbb7101fe9 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go +++ b/cli/azd/extensions/azure.ai.agents/internal/exterrors/codes.go @@ -44,6 +44,7 @@ const ( CodeEnvironmentCreationFailed = "environment_creation_failed" CodeEnvironmentValuesFailed = "environment_values_failed" CodeMissingAiProjectEndpoint = "missing_ai_project_endpoint" + CodeMissingAzureTenantId = "missing_azure_tenant_id" CodeMissingAiProjectId = "missing_ai_project_id" CodeMissingAzureSubscription = "missing_azure_subscription_id" CodeMissingAgentEnvVars = "missing_agent_env_vars" @@ -87,6 +88,19 @@ const ( CodeInvalidFilePath = "invalid_file_path" ) +// Error codes for toolbox operations. +const ( + CodeInvalidToolbox = "invalid_toolbox" + CodeCreateToolboxVersionFailed = "create_toolbox_version_failed" +) + +// Error codes for connection operations. +const ( + CodeInvalidConnection = "invalid_connection" + CodeConnectionCreationFail = "connection_creation_failed" + CodeMissingConnectionField = "missing_connection_field" +) + // Error codes commonly used for internal errors. // // These are usually paired with [Internal] for unexpected failures @@ -108,4 +122,6 @@ const ( OpCreateAgent = "create_agent" OpStartContainer = "start_container" OpGetContainerOperation = "get_container_operation" + OpCreateToolboxVersion = "create_toolbox_version" + OpGetToolbox = "get_toolbox" ) diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go index 0f62aede98d..b2e6b6ba578 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/map.go @@ -418,7 +418,7 @@ func CreateHostedAgentAPIRequest(hostedAgent ContainerAgent, buildConfig *AgentB } else { // Set default protocol versions if none specified protocolVersions = []agent_api.ProtocolVersionRecord{ - {Protocol: agent_api.AgentProtocolResponses, Version: "v1"}, + {Protocol: agent_api.AgentProtocolResponses, Version: "1.0.0"}, } } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go index f859ff8a31a..3048bed1f6e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse.go @@ -170,6 +170,18 @@ func ExtractResourceDefinitions(manifestYamlContent []byte) ([]any, error) { return nil, fmt.Errorf("failed to unmarshal to ToolResource: %w", err) } resourceDefs = append(resourceDefs, toolDef) + case ResourceKindToolbox: + var toolboxDef ToolboxResource + if err := yaml.Unmarshal(resourceBytes, &toolboxDef); err != nil { + return nil, fmt.Errorf("failed to unmarshal to ToolboxResource: %w", err) + } + resourceDefs = append(resourceDefs, toolboxDef) + case ResourceKindConnection: + var connDef ConnectionResource + if err := yaml.Unmarshal(resourceBytes, &connDef); err != nil { + return nil, fmt.Errorf("failed to unmarshal to ConnectionResource: %w", err) + } + resourceDefs = append(resourceDefs, connDef) default: return nil, fmt.Errorf("unrecognized resource kind: %s", resourceDef.Kind) } @@ -296,6 +308,18 @@ func ExtractToolsDefinitions(template map[string]any) ([]any, error) { return nil, fmt.Errorf("failed to unmarshal to CodeInterpreterTool: %w", err) } tools = append(tools, codeInterpreterTool) + case ToolKindAzureAiSearch: + var azureAiSearchTool AzureAISearchTool + if err := yaml.Unmarshal(toolBytes, &azureAiSearchTool); err != nil { + return nil, fmt.Errorf("failed to unmarshal to AzureAISearchTool: %w", err) + } + tools = append(tools, azureAiSearchTool) + case ToolKindA2APreview: + var a2aTool A2APreviewTool + if err := yaml.Unmarshal(toolBytes, &a2aTool); err != nil { + return nil, fmt.Errorf("failed to unmarshal to A2APreviewTool: %w", err) + } + tools = append(tools, a2aTool) default: return nil, fmt.Errorf("unrecognized tool kind: %s", toolDef.Kind) } diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go index 3679d528713..544402252b4 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/parse_test.go @@ -218,3 +218,675 @@ environment_variables: t.Errorf("Expected error message to contain '%s', got '%s'", expectedMsg, err.Error()) } } + +// TestExtractResourceDefinitions_ToolboxResource tests parsing toolbox resources from manifest +func TestExtractResourceDefinitions_ToolboxResource(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: toolbox + name: echo-toolbox + description: A sample toolbox + tools: + - type: mcp + server_label: github + server_url: https://api.example.com/mcp + project_connection_id: TestKey +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + toolboxRes, ok := resources[0].(ToolboxResource) + if !ok { + t.Fatalf("Expected ToolboxResource, got %T", resources[0]) + } + + if toolboxRes.Name != "echo-toolbox" { + t.Errorf("Expected name 'echo-toolbox', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Kind != ResourceKindToolbox { + t.Errorf("Expected kind 'toolbox', got '%s'", toolboxRes.Kind) + } + + if toolboxRes.Description != "A sample toolbox" { + t.Errorf("Expected description 'A sample toolbox', got '%s'", toolboxRes.Description) + } + + if len(toolboxRes.Tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(toolboxRes.Tools)) + } +} + +// TestExtractResourceDefinitions_MixedResources tests parsing both model and toolbox resources +func TestExtractResourceDefinitions_MixedResources(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: model + name: primary-model + id: gpt-4.1-mini + - kind: toolbox + name: my-toolbox + description: My toolbox + tools: + - type: web_search +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 2 { + t.Fatalf("Expected 2 resources, got %d", len(resources)) + } + + if _, ok := resources[0].(ModelResource); !ok { + t.Errorf("Expected first resource to be ModelResource, got %T", resources[0]) + } + + toolboxRes, ok := resources[1].(ToolboxResource) + if !ok { + t.Fatalf("Expected second resource to be ToolboxResource, got %T", resources[1]) + } + + if toolboxRes.Name != "my-toolbox" { + t.Errorf("Expected name 'my-toolbox', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Description != "My toolbox" { + t.Errorf("Expected description 'My toolbox', got '%s'", toolboxRes.Description) + } +} + +// TestExtractResourceDefinitions_ConnectionResource tests parsing connection resources +func TestExtractResourceDefinitions_ConnectionResource(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: context7 + category: CustomKeys + target: https://mcp.context7.com/mcp + authType: CustomKeys + credentials: + key: my-api-key + metadata: + source: context7 +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + connRes, ok := resources[0].(ConnectionResource) + if !ok { + t.Fatalf("Expected ConnectionResource, got %T", resources[0]) + } + + if connRes.Name != "context7" { + t.Errorf("Expected name 'context7', got '%s'", connRes.Name) + } + if connRes.Kind != ResourceKindConnection { + t.Errorf("Expected kind 'connection', got '%s'", connRes.Kind) + } + if connRes.Category != CategoryCustomKeys { + t.Errorf("Expected category 'CustomKeys', got '%s'", connRes.Category) + } + if connRes.Target != "https://mcp.context7.com/mcp" { + t.Errorf("Expected target 'https://mcp.context7.com/mcp', got '%s'", connRes.Target) + } + if connRes.AuthType != AuthTypeCustomKeys { + t.Errorf("Expected authType 'CustomKeys', got '%s'", connRes.AuthType) + } + if connRes.Credentials["key"] != "my-api-key" { + t.Errorf("Expected credentials.key 'my-api-key', got '%v'", connRes.Credentials["key"]) + } + if connRes.Metadata["source"] != "context7" { + t.Errorf("Expected metadata.source 'context7', got '%s'", connRes.Metadata["source"]) + } +} + +// TestExtractResourceDefinitions_AllResourceKinds tests model + toolbox + connection together +func TestExtractResourceDefinitions_AllResourceKinds(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: model + name: chat + id: gpt-4.1-mini + - kind: connection + name: context7 + category: CustomKeys + target: https://mcp.context7.com/mcp + authType: CustomKeys + credentials: + key: test-key + - kind: toolbox + name: agent-tools + description: MCP tools for documentation search + tools: + - type: web_search + - type: mcp + project_connection_id: context7 +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 3 { + t.Fatalf("Expected 3 resources, got %d", len(resources)) + } + + if _, ok := resources[0].(ModelResource); !ok { + t.Errorf("Expected first resource to be ModelResource, got %T", resources[0]) + } + if _, ok := resources[1].(ConnectionResource); !ok { + t.Errorf("Expected second resource to be ConnectionResource, got %T", resources[1]) + } + if _, ok := resources[2].(ToolboxResource); !ok { + t.Errorf("Expected third resource to be ToolboxResource, got %T", resources[2]) + } +} + +// TestExtractResourceDefinitions_ConnectionAllAuthTypes tests all supported auth types +func TestExtractResourceDefinitions_ConnectionAllAuthTypes(t *testing.T) { + authTypes := []AuthType{ + AuthTypeAAD, + AuthTypeApiKey, + AuthTypeCustomKeys, + AuthTypeNone, + AuthTypeOAuth2, + AuthTypePAT, + } + + for _, authType := range authTypes { + t.Run(string(authType), func(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: test-conn + category: CustomKeys + target: https://example.com + authType: ` + string(authType) + ` +`) + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("Failed for authType %s: %v", authType, err) + } + + connRes := resources[0].(ConnectionResource) + if connRes.AuthType != authType { + t.Errorf("Expected authType '%s', got '%s'", authType, connRes.AuthType) + } + }) + } +} + +// TestExtractResourceDefinitions_ConnectionOptionalFields tests optional fields are preserved +func TestExtractResourceDefinitions_ConnectionOptionalFields(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: connection + name: full-conn + category: AzureOpenAI + target: https://myendpoint.openai.azure.com + authType: AAD + expiryTime: "2025-12-31T00:00:00Z" + isSharedToAll: true + sharedUserList: + - user1@contoso.com + - user2@contoso.com + metadata: + env: production + useWorkspaceManagedIdentity: false +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + connRes := resources[0].(ConnectionResource) + + if connRes.ExpiryTime != "2025-12-31T00:00:00Z" { + t.Errorf("Expected expiryTime, got '%s'", connRes.ExpiryTime) + } + if connRes.IsSharedToAll == nil || *connRes.IsSharedToAll != true { + t.Error("Expected isSharedToAll to be true") + } + if len(connRes.SharedUserList) != 2 { + t.Errorf("Expected 2 shared users, got %d", len(connRes.SharedUserList)) + } + if connRes.Metadata["env"] != "production" { + t.Errorf("Expected metadata.env 'production', got '%s'", connRes.Metadata["env"]) + } + if connRes.UseWorkspaceManagedIdentity == nil || *connRes.UseWorkspaceManagedIdentity != false { + t.Error("Expected useWorkspaceManagedIdentity to be false") + } +} + +// TestExtractToolsDefinitions_AzureAiSearch tests parsing an azure_ai_search tool +func TestExtractToolsDefinitions_AzureAiSearch(t *testing.T) { + template := map[string]any{ + "tools": []any{ + map[string]any{ + "name": "my-search", + "kind": "azure_ai_search", + "indexes": []any{ + map[string]any{ + "project_connection_id": "search-conn", + "index_name": "docs-index", + "query_type": "semantic", + "top_k": 10, + }, + }, + }, + }, + } + + tools, err := ExtractToolsDefinitions(template) + if err != nil { + t.Fatalf("ExtractToolsDefinitions failed: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + searchTool, ok := tools[0].(AzureAISearchTool) + if !ok { + t.Fatalf("Expected AzureAISearchTool, got %T", tools[0]) + } + if searchTool.Kind != ToolKindAzureAiSearch { + t.Errorf("Expected kind 'azure_ai_search', got '%s'", searchTool.Kind) + } + if len(searchTool.Indexes) != 1 { + t.Fatalf("Expected 1 index, got %d", len(searchTool.Indexes)) + } + if searchTool.Indexes[0].IndexName != "docs-index" { + t.Errorf("Expected index_name 'docs-index', got '%s'", searchTool.Indexes[0].IndexName) + } +} + +// TestExtractToolsDefinitions_A2APreview tests parsing an a2a_preview tool +func TestExtractToolsDefinitions_A2APreview(t *testing.T) { + template := map[string]any{ + "tools": []any{ + map[string]any{ + "name": "a2a-delegate", + "kind": "a2a_preview", + "baseUrl": "https://remote-agent.example.com", + "agentCardPath": "/.well-known/agent.json", + "projectConnectionId": "remote-conn", + }, + }, + } + + tools, err := ExtractToolsDefinitions(template) + if err != nil { + t.Fatalf("ExtractToolsDefinitions failed: %v", err) + } + + if len(tools) != 1 { + t.Fatalf("Expected 1 tool, got %d", len(tools)) + } + + a2aTool, ok := tools[0].(A2APreviewTool) + if !ok { + t.Fatalf("Expected A2APreviewTool, got %T", tools[0]) + } + if a2aTool.Kind != ToolKindA2APreview { + t.Errorf("Expected kind 'a2a_preview', got '%s'", a2aTool.Kind) + } + if a2aTool.BaseUrl != "https://remote-agent.example.com" { + t.Errorf("Expected baseUrl, got '%s'", a2aTool.BaseUrl) + } + if a2aTool.ProjectConnectionId != "remote-conn" { + t.Errorf("Expected projectConnectionId 'remote-conn', got '%s'", a2aTool.ProjectConnectionId) + } +} + +// TestExtractResourceDefinitions_ToolboxResourceWithTypedTools tests parsing a toolbox +// resource that has tool entries in the Tools []any field, +// matching the AgentSchema ToolboxResource/ToolboxTool format. +func TestExtractResourceDefinitions_ToolboxResourceWithTypedTools(t *testing.T) { + yamlContent := []byte(` +name: test-manifest +template: + kind: prompt + name: test-agent + model: + id: gpt-4.1-mini +resources: + - kind: toolbox + name: platform-tools + description: Platform tools with typed definitions + tools: + - type: bing_grounding + - type: mcp + name: github-copilot + target: https://api.githubcopilot.com/mcp + authType: OAuth2 + credentials: + clientId: my-client-id + clientSecret: my-client-secret + - type: mcp + name: custom-api + target: https://my-api.example.com/sse + authType: CustomKeys + credentials: + key: my-api-key +`) + + resources, err := ExtractResourceDefinitions(yamlContent) + if err != nil { + t.Fatalf("ExtractResourceDefinitions failed: %v", err) + } + + if len(resources) != 1 { + t.Fatalf("Expected 1 resource, got %d", len(resources)) + } + + toolboxRes, ok := resources[0].(ToolboxResource) + if !ok { + t.Fatalf("Expected ToolboxResource, got %T", resources[0]) + } + + if toolboxRes.Name != "platform-tools" { + t.Errorf("Expected name 'platform-tools', got '%s'", toolboxRes.Name) + } + + if toolboxRes.Description != "Platform tools with typed definitions" { + t.Errorf("Expected description, got '%s'", toolboxRes.Description) + } + + if len(toolboxRes.Tools) != 3 { + t.Fatalf("Expected 3 typed tools, got %d", len(toolboxRes.Tools)) + } + + // Helper to get tool as map + tool := func(i int) map[string]any { + m, ok := toolboxRes.Tools[i].(map[string]any) + if !ok { + t.Fatalf("Expected tool[%d] to be map[string]any, got %T", i, toolboxRes.Tools[i]) + } + return m + } + + // Check built-in tool (no target/authType/name) + if tool(0)["type"] != "bing_grounding" { + t.Errorf("Expected first tool type 'bing_grounding', got '%v'", tool(0)["type"]) + } + if tool(0)["target"] != nil { + t.Errorf("Expected no target for built-in tool, got '%v'", tool(0)["target"]) + } + + // Check MCP tool with name and OAuth2 + if tool(1)["type"] != "mcp" { + t.Errorf("Expected second tool type 'mcp', got '%v'", tool(1)["type"]) + } + if tool(1)["name"] != "github-copilot" { + t.Errorf("Expected second tool name 'github-copilot', got '%v'", tool(1)["name"]) + } + if tool(1)["target"] != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected second tool target, got '%v'", tool(1)["target"]) + } + if tool(1)["authType"] != "OAuth2" { + t.Errorf("Expected second tool authType 'OAuth2', got '%v'", tool(1)["authType"]) + } + creds1, _ := tool(1)["credentials"].(map[string]any) + if creds1["clientId"] != "my-client-id" { + t.Errorf("Expected second tool clientId, got '%v'", creds1["clientId"]) + } + + // Check MCP tool with CustomKeys + if tool(2)["type"] != "mcp" { + t.Errorf("Expected third tool type 'mcp', got '%v'", tool(2)["type"]) + } + if tool(2)["name"] != "custom-api" { + t.Errorf("Expected third tool name 'custom-api', got '%v'", tool(2)["name"]) + } + if tool(2)["authType"] != "CustomKeys" { + t.Errorf("Expected third tool authType 'CustomKeys', got '%v'", tool(2)["authType"]) + } +} + +// TestLoadAndValidateAgentManifest_RecordFormatParameters verifies that the +// record/map format for parameters (canonical agent manifest schema) is parsed +// correctly into PropertySchema.Properties. +func TestLoadAndValidateAgentManifest_RecordFormatParameters(t *testing.T) { + yamlContent := []byte(` +name: test-params +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 + - kind: toolbox + name: tools + tools: + - type: mcp + name: github + target: https://api.githubcopilot.com/mcp + authType: OAuth2 + credentials: + clientId: "{{ github_client_id }}" + clientSecret: "{{ github_client_secret }}" +parameters: + github_client_id: + schema: + type: string + description: OAuth client ID + required: true + github_client_secret: + schema: + type: string + description: OAuth client secret + required: true + model_name: + schema: + type: string + enum: + - gpt-4o + - gpt-4o-mini + default: gpt-4o + required: true +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 3 { + t.Fatalf("Expected 3 parameters, got %d", len(manifest.Parameters.Properties)) + } + + // Find parameters by name (map order is not guaranteed) + paramsByName := map[string]Property{} + for _, p := range manifest.Parameters.Properties { + paramsByName[p.Name] = p + } + + // Check github_client_id + p, ok := paramsByName["github_client_id"] + if !ok { + t.Fatal("Missing parameter github_client_id") + } + if p.Kind != "string" { + t.Errorf("Expected kind 'string', got '%s'", p.Kind) + } + if p.Description == nil || *p.Description != "OAuth client ID" { + t.Errorf("Unexpected description: %v", p.Description) + } + if p.Required == nil || !*p.Required { + t.Error("Expected required=true") + } + + // Check model_name with enum and default + p, ok = paramsByName["model_name"] + if !ok { + t.Fatal("Missing parameter model_name") + } + if p.EnumValues == nil || len(*p.EnumValues) != 2 { + t.Fatalf("Expected 2 enum values, got %v", p.EnumValues) + } + if p.Default == nil { + t.Fatal("Expected default value") + } + if defaultStr, ok := (*p.Default).(string); !ok || defaultStr != "gpt-4o" { + t.Errorf("Expected default 'gpt-4o', got %v", *p.Default) + } +} + +// TestLoadAndValidateAgentManifest_ArrayFormatParameters verifies that the +// traditional array format for parameters still works after the UnmarshalYAML change. +func TestLoadAndValidateAgentManifest_ArrayFormatParameters(t *testing.T) { + yamlContent := []byte(` +name: test-array-params +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 +parameters: + properties: + - name: my_param + kind: string + description: A test parameter + required: true +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 1 { + t.Fatalf("Expected 1 parameter, got %d", len(manifest.Parameters.Properties)) + } + + p := manifest.Parameters.Properties[0] + if p.Name != "my_param" { + t.Errorf("Expected name 'my_param', got '%s'", p.Name) + } + if p.Kind != "string" { + t.Errorf("Expected kind 'string', got '%s'", p.Kind) + } +} + +// TestLoadAndValidateAgentManifest_SecretParameter verifies that +// secret: true inside the schema block is parsed into Property.Secret. +func TestLoadAndValidateAgentManifest_SecretParameter(t *testing.T) { + yamlContent := []byte(` +name: test-secret +template: + name: test + kind: hosted + protocols: + - protocol: responses +resources: + - kind: model + name: chat + id: gpt-5 +parameters: + api_key: + description: API key for the custom MCP server + schema: + type: string + secret: true + required: true + display_name: + description: A non-secret parameter + schema: + type: string +`) + + manifest, err := LoadAndValidateAgentManifest(yamlContent) + if err != nil { + t.Fatalf("LoadAndValidateAgentManifest failed: %v", err) + } + + if len(manifest.Parameters.Properties) != 2 { + t.Fatalf("Expected 2 parameters, got %d", len(manifest.Parameters.Properties)) + } + + paramsByName := map[string]Property{} + for _, p := range manifest.Parameters.Properties { + paramsByName[p.Name] = p + } + + // api_key should be secret + apiKey, ok := paramsByName["api_key"] + if !ok { + t.Fatal("Missing parameter api_key") + } + if apiKey.Secret == nil || !*apiKey.Secret { + t.Error("Expected api_key to have secret=true") + } + + // display_name should NOT be secret + displayName, ok := paramsByName["display_name"] + if !ok { + t.Fatal("Missing parameter display_name") + } + if displayName.Secret != nil { + t.Errorf("Expected display_name secret to be nil, got %v", *displayName.Secret) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go index bb9231c7651..13038ee5ccd 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml.go @@ -3,7 +3,12 @@ package agent_yaml -import "slices" +import ( + "fmt" + "slices" + + "go.yaml.in/yaml/v3" +) // AgentKind represents the type of agent type AgentKind string @@ -31,8 +36,10 @@ func ValidAgentKinds() []AgentKind { type ResourceKind string const ( - ResourceKindModel ResourceKind = "model" - ResourceKindTool ResourceKind = "tool" + ResourceKindModel ResourceKind = "model" + ResourceKindTool ResourceKind = "tool" + ResourceKindToolbox ResourceKind = "toolbox" + ResourceKindConnection ResourceKind = "connection" ) type ToolKind string @@ -40,12 +47,55 @@ type ToolKind string const ( ToolKindFunction ToolKind = "function" ToolKindCustom ToolKind = "custom" - ToolKindWebSearch ToolKind = "webSearch" - ToolKindBingGrounding ToolKind = "bingGrounding" - ToolKindFileSearch ToolKind = "fileSearch" + ToolKindWebSearch ToolKind = "web_search" + ToolKindBingGrounding ToolKind = "bing_grounding" + ToolKindFileSearch ToolKind = "file_search" ToolKindMcp ToolKind = "mcp" - ToolKindOpenApi ToolKind = "openApi" - ToolKindCodeInterpreter ToolKind = "codeInterpreter" + ToolKindOpenApi ToolKind = "openapi" + ToolKindCodeInterpreter ToolKind = "code_interpreter" + ToolKindAzureAiSearch ToolKind = "azure_ai_search" + ToolKindA2APreview ToolKind = "a2a_preview" +) + +// AuthType represents the authentication type for a connection. +type AuthType string + +const ( + AuthTypeAAD AuthType = "AAD" + AuthTypeApiKey AuthType = "ApiKey" + AuthTypeCustomKeys AuthType = "CustomKeys" + AuthTypeNone AuthType = "None" + AuthTypeOAuth2 AuthType = "OAuth2" + AuthTypePAT AuthType = "PAT" +) + +// CategoryKind represents the category of a connection resource. +type CategoryKind string + +const ( + CategoryAzureOpenAI CategoryKind = "AzureOpenAI" + CategoryCognitiveSearch CategoryKind = "CognitiveSearch" + CategoryCognitiveService CategoryKind = "CognitiveService" + CategoryCustomKeys CategoryKind = "CustomKeys" + CategoryServerlessEndpoint CategoryKind = "Serverless" + CategoryContainerRegistry CategoryKind = "ContainerRegistry" + CategoryApiKey CategoryKind = "ApiKey" + CategoryAzureBlob CategoryKind = "AzureBlob" + CategoryGit CategoryKind = "Git" + CategoryRedis CategoryKind = "Redis" + CategoryS3 CategoryKind = "S3" + CategorySnowflake CategoryKind = "Snowflake" + CategoryAzureSqlDb CategoryKind = "AzureSqlDb" + CategoryAzureSynapseAnalytics CategoryKind = "AzureSynapseAnalytics" + CategoryAzureMySqlDb CategoryKind = "AzureMySqlDb" + CategoryAzurePostgresDb CategoryKind = "AzurePostgresDb" + CategoryADLSGen2 CategoryKind = "ADLSGen2" + CategoryAzureDataExplorer CategoryKind = "AzureDataExplorer" + CategoryBingLLMSearch CategoryKind = "BingLLMSearch" + CategoryMicrosoftOneLake CategoryKind = "MicrosoftOneLake" + CategoryElasticSearch CategoryKind = "Elasticsearch" + CategoryPinecone CategoryKind = "Pinecone" + CategoryQdrant CategoryKind = "Qdrant" ) type ConnectionKind string @@ -252,6 +302,7 @@ type Property struct { Default *any `json:"default,omitempty" yaml:"default,omitempty"` Example *any `json:"example,omitempty" yaml:"example,omitempty"` EnumValues *[]any `json:"enumValues,omitempty" yaml:"enumValues,omitempty"` + Secret *bool `json:"secret,omitempty" yaml:"secret,omitempty"` } // ArrayProperty Represents an array property. @@ -272,10 +323,239 @@ type ObjectProperty struct { // PropertySchema Definition for the property schema of a model. // This includes the properties and example records. +// +// The schema supports two YAML layouts for Properties: +// +// Array format (explicit): +// +// properties: +// - name: foo +// kind: string +// +// Record/map format (canonical agent manifest shorthand): +// +// parameters: +// foo: +// schema: { type: string } +// description: a foo param +// required: true +// +// UnmarshalYAML detects which layout is present and normalises to []Property. type PropertySchema struct { - Examples *[]map[string]any `json:"examples,omitempty" yaml:"examples,omitempty"` - Strict *bool `json:"strict,omitempty" yaml:"strict,omitempty"` - Properties []Property `json:"properties" yaml:"properties"` + Examples *[]map[string]any `json:"examples,omitempty" yaml:"-"` + Strict *bool `json:"strict,omitempty" yaml:"-"` + Properties []Property `json:"properties" yaml:"-"` +} + +// UnmarshalYAML supports both the array format (properties: []) and the +// record/map format where parameter names are direct YAML keys. +func (ps *PropertySchema) UnmarshalYAML(value *yaml.Node) error { + // The node should be a mapping. + if value.Kind != yaml.MappingNode { + return fmt.Errorf("PropertySchema: expected mapping node, got %d", value.Kind) + } + + // First pass: look for known struct keys (examples, strict, properties). + // Anything else is treated as a record-format parameter name. + var ( + propertiesNode *yaml.Node + extraKeys []string + extraValues []*yaml.Node + ) + + for i := 0; i < len(value.Content)-1; i += 2 { + key := value.Content[i].Value + val := value.Content[i+1] + + switch key { + case "examples": + var examples []map[string]any + if err := val.Decode(&examples); err != nil { + return fmt.Errorf("PropertySchema.examples: %w", err) + } + ps.Examples = &examples + case "strict": + var strict bool + if err := val.Decode(&strict); err != nil { + return fmt.Errorf("PropertySchema.strict: %w", err) + } + ps.Strict = &strict + case "properties": + propertiesNode = val + default: + extraKeys = append(extraKeys, key) + extraValues = append(extraValues, val) + } + } + + // If an explicit "properties" key was found, decode it (array or map). + if propertiesNode != nil { + props, err := decodePropertiesNode(propertiesNode) + if err != nil { + return fmt.Errorf("PropertySchema.properties: %w", err) + } + ps.Properties = props + return nil + } + + // No explicit "properties" key — treat extra keys as record-format params. + if len(extraKeys) > 0 { + for i, name := range extraKeys { + prop, err := decodeRecordProperty(name, extraValues[i]) + if err != nil { + return fmt.Errorf("PropertySchema parameter %q: %w", name, err) + } + ps.Properties = append(ps.Properties, prop) + } + } + + return nil +} + +// decodePropertiesNode handles "properties:" as either an array or a map. +func decodePropertiesNode(node *yaml.Node) ([]Property, error) { + switch node.Kind { + case yaml.SequenceNode: + var props []Property + if err := node.Decode(&props); err != nil { + return nil, err + } + return props, nil + case yaml.MappingNode: + var props []Property + for i := 0; i < len(node.Content)-1; i += 2 { + name := node.Content[i].Value + prop, err := decodeRecordProperty(name, node.Content[i+1]) + if err != nil { + return nil, fmt.Errorf("property %q: %w", name, err) + } + props = append(props, prop) + } + return props, nil + default: + return nil, fmt.Errorf("expected sequence or mapping, got %d", node.Kind) + } +} + +// recordEntry is the intermediate structure for parsing a single record-format +// parameter entry like: +// +// param_name: +// schema: { type: string, enum: [...], default: ... } +// description: some text +// required: true +type recordEntry struct { + Schema *recordSchema `yaml:"schema"` + Description string `yaml:"description"` + Required bool `yaml:"required"` + Default *any `yaml:"default"` + Example *any `yaml:"example"` + EnumValues *[]any `yaml:"enumValues"` +} + +type recordSchema struct { + Type string `yaml:"type"` + Enum []any `yaml:"enum"` + Default *any `yaml:"default"` + Secret bool `yaml:"secret"` +} + +// decodeRecordProperty converts a record-format parameter entry into a Property. +func decodeRecordProperty(name string, node *yaml.Node) (Property, error) { + var entry recordEntry + if err := node.Decode(&entry); err != nil { + return Property{}, err + } + + prop := Property{Name: name} + if entry.Description != "" { + prop.Description = &entry.Description + } + if entry.Required { + prop.Required = &entry.Required + } + if entry.Default != nil { + prop.Default = entry.Default + } + if entry.Example != nil { + prop.Example = entry.Example + } + if entry.EnumValues != nil { + prop.EnumValues = entry.EnumValues + } + + // Extract kind/default/enum/secret from nested schema if present + if entry.Schema != nil { + prop.Kind = entry.Schema.Type + if entry.Schema.Default != nil && prop.Default == nil { + prop.Default = entry.Schema.Default + } + if len(entry.Schema.Enum) > 0 && prop.EnumValues == nil { + prop.EnumValues = &entry.Schema.Enum + } + if entry.Schema.Secret { + prop.Secret = new(true) + } + } + + return prop, nil +} + +// MarshalYAML writes PropertySchema back as the record/map format so that +// {{param}} placeholders elsewhere in the document survive a marshal→unmarshal +// round-trip through InjectParameterValuesIntoManifest. +func (ps PropertySchema) MarshalYAML() (any, error) { + out := make(map[string]any) + + if ps.Examples != nil { + out["examples"] = *ps.Examples + } + if ps.Strict != nil { + out["strict"] = *ps.Strict + } + + // Emit each property as a record-format entry. + props := make(map[string]any, len(ps.Properties)) + for _, p := range ps.Properties { + entry := map[string]any{} + schema := map[string]any{} + + if p.Kind != "" { + schema["type"] = p.Kind + } + if p.Default != nil { + schema["default"] = *p.Default + } + if p.EnumValues != nil { + schema["enum"] = *p.EnumValues + } + if p.Secret != nil && *p.Secret { + schema["secret"] = true + } + if len(schema) > 0 { + entry["schema"] = schema + } + + if p.Description != nil { + entry["description"] = *p.Description + } + if p.Required != nil { + entry["required"] = *p.Required + } + if p.Example != nil { + entry["example"] = *p.Example + } + props[p.Name] = entry + } + + if len(props) > 0 { + // Merge property keys at the top level (record format) + for k, v := range props { + out[k] = v + } + } + + return out, nil } // ProtocolVersionRecord represents a protocolversionrecord. @@ -305,6 +585,34 @@ type ToolResource struct { Options map[string]any `json:"options" yaml:"options"` } +// ToolboxResource Represents a toolbox resource required by the agent. +// A toolbox is a reusable collection of tools that can be deployed as a Foundry Toolset. +type ToolboxResource struct { + Resource `json:",inline" yaml:",inline"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Tools []any `json:"tools" yaml:"tools"` +} + +// ConnectionResource Represents a connection resource required by the agent. +// Maps to the Bicep ConnectionPropertiesV2 spec for creating project connections. +type ConnectionResource struct { + Resource `json:",inline" yaml:",inline"` + Category CategoryKind `json:"category" yaml:"category"` + Target string `json:"target" yaml:"target"` + AuthType AuthType `json:"authType" yaml:"authType"` + Credentials map[string]any `json:"credentials,omitempty" yaml:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty" yaml:"metadata,omitempty"` + ExpiryTime string `json:"expiryTime,omitempty" yaml:"expiryTime,omitempty"` + IsSharedToAll *bool `json:"isSharedToAll,omitempty" yaml:"isSharedToAll,omitempty"` + SharedUserList []string `json:"sharedUserList,omitempty" yaml:"sharedUserList,omitempty"` + PeRequirement string `json:"peRequirement,omitempty" yaml:"peRequirement,omitempty"` + PeStatus string `json:"peStatus,omitempty" yaml:"peStatus,omitempty"` + Error string `json:"error,omitempty" yaml:"error,omitempty"` + + // UseWorkspaceManagedIdentity indicates whether to use workspace managed identity. + UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty" yaml:"useWorkspaceManagedIdentity,omitempty"` //nolint:lll +} + // Template Template model for defining prompt templates. // This model specifies the rendering engine used for slot filling prompts, // the parser used to process the rendered template into API-compatible format, @@ -396,3 +704,57 @@ type CodeInterpreterTool struct { FileIds []string `json:"fileIds" yaml:"fileIds"` Options map[string]any `json:"options" yaml:"options"` } + +// AzureAISearchIndex represents a single index configuration within an AzureAISearchTool. +type AzureAISearchIndex struct { + ProjectConnectionId string `json:"project_connection_id" yaml:"project_connection_id"` + IndexName string `json:"index_name" yaml:"index_name"` + QueryType *string `json:"query_type,omitempty" yaml:"query_type,omitempty"` + TopK *int `json:"top_k,omitempty" yaml:"top_k,omitempty"` + Filter *string `json:"filter,omitempty" yaml:"filter,omitempty"` +} + +// AzureAISearchTool The Azure AI Search tool for grounding agent responses with search index data. +type AzureAISearchTool struct { + Tool `json:",inline" yaml:",inline"` + Indexes []AzureAISearchIndex `json:"indexes" yaml:"indexes"` +} + +// A2APreviewTool The A2A (Agent-to-Agent) preview tool for delegating tasks to other agents. +type A2APreviewTool struct { + Tool `json:",inline" yaml:",inline"` + BaseUrl string `json:"baseUrl" yaml:"baseUrl"` + AgentCardPath *string `json:"agentCardPath,omitempty" yaml:"agentCardPath,omitempty"` + ProjectConnectionId string `json:"projectConnectionId" yaml:"projectConnectionId"` +} + +// Credential type structs for typed access to connection credentials. +// The ConnectionResource.Credentials field is map[string]any for flexibility, +// but these structs can be used when code needs structured access. + +// ApiKeyCredentials holds credentials for ApiKey auth type. +type ApiKeyCredentials struct { + Key string `json:"key" yaml:"key"` +} + +// CustomKeysCredentials holds credentials for CustomKeys auth type. +type CustomKeysCredentials struct { + Keys map[string]string `json:"keys" yaml:"keys"` +} + +// OAuth2Credentials holds credentials for OAuth2 auth type. +type OAuth2Credentials struct { + AuthUrl string `json:"authUrl,omitempty" yaml:"authUrl,omitempty"` + ClientId string `json:"clientId" yaml:"clientId"` + ClientSecret string `json:"clientSecret,omitempty" yaml:"clientSecret,omitempty"` + DeveloperToken string `json:"developerToken,omitempty" yaml:"developerToken,omitempty"` + Password string `json:"password,omitempty" yaml:"password,omitempty"` + RefreshToken string `json:"refreshToken,omitempty" yaml:"refreshToken,omitempty"` + TenantId string `json:"tenantId,omitempty" yaml:"tenantId,omitempty"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` +} + +// PATCredentials holds credentials for PAT (Personal Access Token) auth type. +type PATCredentials struct { + Pat string `json:"pat" yaml:"pat"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go index fba8c971244..398cceb368e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/agent_yaml/yaml_test.go @@ -23,3 +23,108 @@ func TestArrayProperty_BasicSerialization(t *testing.T) { t.Fatalf("Failed to unmarshal ArrayProperty: %v", err) } } + +// TestConnectionResourceSerialization tests JSON round-trip for ConnectionResource +func TestConnectionResourceSerialization(t *testing.T) { + conn := ConnectionResource{ + Resource: Resource{Name: "test-conn", Kind: ResourceKindConnection}, + Category: CategoryCustomKeys, + Target: "https://example.com", + AuthType: AuthTypeCustomKeys, + Credentials: map[string]any{"key": "secret"}, + Metadata: map[string]string{"env": "test"}, + ExpiryTime: "2025-12-31", + IsSharedToAll: new(true), + } + + data, err := json.Marshal(conn) + if err != nil { + t.Fatalf("Failed to marshal ConnectionResource: %v", err) + } + + var conn2 ConnectionResource + if err := json.Unmarshal(data, &conn2); err != nil { + t.Fatalf("Failed to unmarshal ConnectionResource: %v", err) + } + + if conn2.Name != "test-conn" { + t.Errorf("Expected name 'test-conn', got '%s'", conn2.Name) + } + if conn2.AuthType != AuthTypeCustomKeys { + t.Errorf("Expected authType 'CustomKeys', got '%s'", conn2.AuthType) + } + if conn2.IsSharedToAll == nil || !*conn2.IsSharedToAll { + t.Error("Expected isSharedToAll to be true") + } +} + +// TestAzureAISearchToolSerialization tests JSON round-trip for AzureAISearchTool +func TestAzureAISearchToolSerialization(t *testing.T) { + tool := AzureAISearchTool{ + Tool: Tool{ + Name: "search-tool", + Kind: ToolKindAzureAiSearch, + }, + Indexes: []AzureAISearchIndex{ + { + ProjectConnectionId: "my-conn", + IndexName: "my-index", + TopK: new(5), + }, + }, + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("Failed to marshal AzureAISearchTool: %v", err) + } + + var tool2 AzureAISearchTool + if err := json.Unmarshal(data, &tool2); err != nil { + t.Fatalf("Failed to unmarshal AzureAISearchTool: %v", err) + } + + if tool2.Kind != ToolKindAzureAiSearch { + t.Errorf("Expected kind 'azure_ai_search', got '%s'", tool2.Kind) + } + if len(tool2.Indexes) != 1 { + t.Fatalf("Expected 1 index, got %d", len(tool2.Indexes)) + } + if tool2.Indexes[0].IndexName != "my-index" { + t.Errorf("Expected index_name 'my-index', got '%s'", tool2.Indexes[0].IndexName) + } +} + +// TestA2APreviewToolSerialization tests JSON round-trip for A2APreviewTool +func TestA2APreviewToolSerialization(t *testing.T) { + agentCardPath := "/.well-known/agent-card.json" + tool := A2APreviewTool{ + Tool: Tool{ + Name: "a2a-tool", + Kind: ToolKindA2APreview, + }, + BaseUrl: "https://agent.example.com", + AgentCardPath: &agentCardPath, + ProjectConnectionId: "my-conn", + } + + data, err := json.Marshal(tool) + if err != nil { + t.Fatalf("Failed to marshal A2APreviewTool: %v", err) + } + + var tool2 A2APreviewTool + if err := json.Unmarshal(data, &tool2); err != nil { + t.Fatalf("Failed to unmarshal A2APreviewTool: %v", err) + } + + if tool2.Kind != ToolKindA2APreview { + t.Errorf("Expected kind 'a2a_preview', got '%s'", tool2.Kind) + } + if tool2.BaseUrl != "https://agent.example.com" { + t.Errorf("Expected baseUrl 'https://agent.example.com', got '%s'", tool2.BaseUrl) + } + if tool2.AgentCardPath == nil || *tool2.AgentCardPath != agentCardPath { + t.Errorf("Expected agentCardPath '%s'", agentCardPath) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go index 77586036d92..be342bf8b27 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/agents/registry_api/helpers.go @@ -6,6 +6,7 @@ package registry_api import ( "context" "fmt" + "os" "reflect" "strconv" "strings" @@ -15,6 +16,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/braydonk/yaml" + "golang.org/x/term" ) // ParameterValues represents the user-provided values for manifest parameters @@ -346,9 +348,13 @@ func promptForYamlParameterValues( var value any var err error isRequired := property.Required != nil && *property.Required + isSecret := property.Secret != nil && *property.Secret if len(enumValues) > 0 { // Use selection for enum parameters value, err = promptForEnumValue(ctx, property.Name, enumValues, defaultValue, azdClient, noPrompt) + } else if isSecret { + // Use masked input for secret parameters + value, err = promptForSecretValue(property.Name, isRequired) } else { // Use text input for other parameters value, err = promptForTextValue(ctx, property.Name, defaultValue, isRequired, azdClient) @@ -451,7 +457,7 @@ func promptForTextValue( defaultStr = fmt.Sprintf("%v", defaultValue) } - message := fmt.Sprintf("Enter value for parameter '%s':", paramName) + message := fmt.Sprintf("Enter value for parameter '%s'", paramName) if defaultStr != "" { message += fmt.Sprintf(" (default: %s)", defaultStr) } @@ -480,6 +486,24 @@ func promptForTextValue( return resp.Value, nil } +// promptForSecretValue reads a secret value from the terminal with masked input. +func promptForSecretValue(paramName string, required bool) (any, error) { + fmt.Printf("Enter value for parameter '%s': ", paramName) + + secret, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Println() // newline after masked input + if err != nil { + return nil, fmt.Errorf("failed to read secret value: %w", err) + } + + value := strings.TrimSpace(string(secret)) + if value == "" && required { + return nil, fmt.Errorf("parameter '%s' is required but no value was provided", paramName) + } + + return value, nil +} + // injectParameterValues replaces parameter placeholders in the template with actual values func injectParameterValues(template string, paramValues ParameterValues) ([]byte, error) { // Replace each parameter placeholder with its value diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go new file mode 100644 index 00000000000..537e09d1a69 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client.go @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "azureaiagent/internal/version" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +const ( + toolboxesApiVersion = "v1" + toolboxesFeatureHeader = "Toolboxes=V1Preview" +) + +// FoundryToolboxClient provides methods for interacting with the Foundry Toolboxes API. +type FoundryToolboxClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewFoundryToolboxClient creates a new FoundryToolboxClient. +func NewFoundryToolboxClient( + endpoint string, + cred azcore.TokenCredential, +) *FoundryToolboxClient { + userAgent := fmt.Sprintf("azd-ext-azure-ai-agents/%s", version.Version) + + clientOptions := &policy.ClientOptions{ + Logging: policy.LogOptions{ + AllowedHeaders: []string{azsdk.MsCorrelationIdHeader, "X-Request-Id"}, + IncludeBody: true, + }, + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy(cred, []string{"https://ai.azure.com/.default"}, nil), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy(userAgent), + }, + } + + pipeline := runtime.NewPipeline( + "azure-ai-agents", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &FoundryToolboxClient{ + endpoint: endpoint, + pipeline: pipeline, + } +} + +// CreateToolboxVersionRequest is the request body for creating a new toolbox version. +// The toolbox name is provided in the URL path, not in the body. +type CreateToolboxVersionRequest struct { + Description string `json:"description,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// ToolboxObject is the lightweight response for a toolbox (no tools list). +type ToolboxObject struct { + Id string `json:"id"` + Name string `json:"name"` + DefaultVersion string `json:"default_version"` +} + +// ToolboxVersionObject is the response for a specific toolbox version. +type ToolboxVersionObject struct { + Id string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description,omitempty"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]string `json:"metadata,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// CreateToolboxVersion creates a new version of a toolbox. +// If the toolbox does not exist, it will be created automatically. +func (c *FoundryToolboxClient) CreateToolboxVersion( + ctx context.Context, + toolboxName string, + request *CreateToolboxVersionRequest, +) (*ToolboxVersionObject, error) { + targetUrl := fmt.Sprintf( + "%s/toolboxes/%s/versions?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, + ) + + payload, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := runtime.NewRequest(ctx, http.MethodPost, targetUrl) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) + + if err := req.SetBody( + streaming.NopCloser(bytes.NewReader(payload)), + "application/json", + ); err != nil { + return nil, fmt.Errorf("failed to set request body: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result ToolboxVersionObject + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// GetToolbox retrieves a toolbox by name. +func (c *FoundryToolboxClient) GetToolbox( + ctx context.Context, + toolboxName string, +) (*ToolboxObject, error) { + targetUrl := fmt.Sprintf( + "%s/toolboxes/%s?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, + ) + + req, err := runtime.NewRequest(ctx, http.MethodGet, targetUrl) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var result ToolboxObject + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + +// DeleteToolbox deletes a toolbox and all its versions. +func (c *FoundryToolboxClient) DeleteToolbox( + ctx context.Context, + toolboxName string, +) error { + targetUrl := fmt.Sprintf( + "%s/toolboxes/%s?api-version=%s", + c.endpoint, url.PathEscape(toolboxName), toolboxesApiVersion, + ) + + req, err := runtime.NewRequest(ctx, http.MethodDelete, targetUrl) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Raw().Header.Set("Foundry-Features", toolboxesFeatureHeader) + + resp, err := c.pipeline.Do(req) + if err != nil { + return fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK, http.StatusNoContent) { + return runtime.NewResponseError(resp) + } + + return nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go new file mode 100644 index 00000000000..d2818625f70 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_toolsets_client_test.go @@ -0,0 +1,288 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/require" +) + +// roundTripFunc is a test helper that captures HTTP requests. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// newTestPipeline creates an Azure SDK pipeline backed by a custom round-tripper. +func newTestPipeline(fn roundTripFunc) runtime.Pipeline { + return runtime.NewPipeline( + "test", + "v0.0.0", + runtime.PipelineOptions{}, + &policy.ClientOptions{ + Transport: &http.Client{Transport: fn}, + }, + ) +} + +// newTestToolboxClient creates a FoundryToolboxClient backed by a custom +// HTTP round-tripper so we can inspect requests and control responses +// without touching the network. +func newTestToolboxClient( + endpoint string, + fn roundTripFunc, +) *FoundryToolboxClient { + return &FoundryToolboxClient{ + endpoint: endpoint, + pipeline: newTestPipeline(fn), + } +} + +func TestCreateToolboxVersion_URLConstruction(t *testing.T) { + tests := []struct { + name string + endpoint string + toolboxName string + wantPath string + wantQuery string + }{ + { + name: "simple name", + endpoint: "https://example.com", + toolboxName: "my-toolbox", + wantPath: "/toolboxes/my-toolbox/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + { + name: "name with special chars is escaped", + endpoint: "https://example.com", + toolboxName: "my toolbox/v2", + wantPath: "/toolboxes/my%20toolbox%2Fv2/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + { + name: "endpoint with trailing slash", + endpoint: "https://example.com/", + toolboxName: "tools", + wantPath: "//toolboxes/tools/versions", + wantQuery: "api-version=" + toolboxesApiVersion, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient(tt.endpoint, func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), tt.toolboxName, &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, http.MethodPost, capturedReq.Method) + require.Equal(t, tt.wantPath, capturedReq.URL.EscapedPath()) + require.Equal(t, tt.wantQuery, capturedReq.URL.RawQuery) + }) + } +} + +func TestCreateToolboxVersion_RequiredHeaders(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), "test-toolbox", &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, toolboxesFeatureHeader, capturedReq.Header.Get("Foundry-Features")) + require.Equal(t, "application/json", capturedReq.Header.Get("Content-Type")) +} + +func TestCreateToolboxVersion_ErrorStatusCodes(t *testing.T) { + tests := []struct { + name string + statusCode int + wantErr bool + }{ + {"200 OK", http.StatusOK, false}, + {"400 Bad Request", http.StatusBadRequest, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"409 Conflict", http.StatusConflict, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","version":"v1","tools":[]}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.CreateToolboxVersion(t.Context(), "test", &CreateToolboxVersionRequest{ + Tools: []map[string]any{}, + }) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetToolbox_URLConstruction(t *testing.T) { + tests := []struct { + name string + toolboxName string + wantPath string + }{ + { + name: "simple name", + toolboxName: "my-toolbox", + wantPath: "/toolboxes/my-toolbox", + }, + { + name: "name needing escape", + toolboxName: "test/box", + wantPath: "/toolboxes/" + url.PathEscape("test/box"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","default_version":"v1"}`)), + Header: make(http.Header), + }, nil + }) + + _, err := client.GetToolbox(t.Context(), tt.toolboxName) + require.NoError(t, err) + require.NotNil(t, capturedReq) + + require.Equal(t, http.MethodGet, capturedReq.Method) + require.Equal(t, tt.wantPath, capturedReq.URL.EscapedPath()) + require.Equal(t, toolboxesFeatureHeader, capturedReq.Header.Get("Foundry-Features")) + }) + } +} + +func TestDeleteToolbox_URLAndStatusCodes(t *testing.T) { + tests := []struct { + name string + toolboxName string + statusCode int + wantErr bool + }{ + {"200 OK", "my-toolbox", http.StatusOK, false}, + {"204 No Content", "my-toolbox", http.StatusNoContent, false}, + {"404 Not Found", "my-toolbox", http.StatusNotFound, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: tt.statusCode, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + }) + + err := client.DeleteToolbox(t.Context(), tt.toolboxName) + require.NotNil(t, capturedReq) + require.Equal(t, http.MethodDelete, capturedReq.Method) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestToolboxClient_PathEscaping_Adversarial(t *testing.T) { + tests := []struct { + name string + toolboxName string + }{ + {"path traversal", "../../../etc/passwd"}, + {"slashes in name", "name/with/slashes"}, + {"backslash traversal", `..\..\evil`}, + {"URL-encoded slash", "..%2F..%2Fevil"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var capturedReq *http.Request + + client := newTestToolboxClient("https://example.com", func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"id":"1","name":"tb","default_version":"v1"}`)), + Header: make(http.Header), + }, nil + }) + + _, _ = client.GetToolbox(t.Context(), tt.toolboxName) + require.NotNil(t, capturedReq) + + // The escaped name must not introduce extra path segments + escaped := url.PathEscape(tt.toolboxName) + expectedPath := fmt.Sprintf("/toolboxes/%s", escaped) + require.Equal(t, expectedPath, capturedReq.URL.EscapedPath()) + + // No raw slashes in the toolbox name segment + escapedPath := capturedReq.URL.EscapedPath() + segments := strings.Split(strings.Trim(escapedPath, "/"), "/") + // Should be exactly: ["toolboxes", ""] + // (plus query params handled separately) + require.GreaterOrEqual(t, len(segments), 2, + "path should have at least 2 segments") + require.Equal(t, "toolboxes", segments[0]) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config.go b/cli/azd/extensions/azure.ai.agents/internal/project/config.go index 4694d649466..78109795bc2 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/config.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config.go @@ -44,11 +44,14 @@ var ResourceTiers = []ResourceTier{ // ServiceTargetAgentConfig provides custom configuration for the Azure AI Service target type ServiceTargetAgentConfig struct { - Environment map[string]string `json:"env,omitempty"` - Container *ContainerSettings `json:"container,omitempty"` - Deployments []Deployment `json:"deployments,omitempty"` - Resources []Resource `json:"resources,omitempty"` - StartupCommand string `json:"startupCommand,omitempty"` + Environment map[string]string `json:"env,omitempty"` + Container *ContainerSettings `json:"container,omitempty"` + Deployments []Deployment `json:"deployments,omitempty"` + Resources []Resource `json:"resources,omitempty"` + ToolConnections []ToolConnection `json:"toolConnections,omitempty"` + Toolboxes []Toolbox `json:"toolboxes,omitempty"` + Connections []Connection `json:"connections,omitempty"` + StartupCommand string `json:"startupCommand,omitempty"` } // ContainerSettings provides container configuration for the Azure AI Service target @@ -108,6 +111,41 @@ type Resource struct { ConnectionName string `json:"connectionName"` } +// Toolbox represents a reusable collection of tools deployed as a Foundry Toolset +type Toolbox struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Tools []map[string]any `json:"tools"` +} + +// Connection represents a project connection matching the Bicep ConnectionPropertiesV2 spec. +type Connection struct { + Name string `json:"name"` + Category string `json:"category"` + Target string `json:"target"` + AuthType string `json:"authType"` + Credentials map[string]any `json:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + ExpiryTime string `json:"expiryTime,omitempty"` + IsSharedToAll *bool `json:"isSharedToAll,omitempty"` + SharedUserList []string `json:"sharedUserList,omitempty"` + PeRequirement string `json:"peRequirement,omitempty"` + PeStatus string `json:"peStatus,omitempty"` + UseWorkspaceManagedIdentity *bool `json:"useWorkspaceManagedIdentity,omitempty"` + Error string `json:"error,omitempty"` +} + +// ToolConnection represents a connection to an external service (MCP tool, A2A, custom API) +// that must be created in the Foundry project during provisioning via Bicep. +type ToolConnection struct { + Name string `json:"name"` + Category string `json:"category"` + Target string `json:"target"` + AuthType string `json:"authType"` + Credentials map[string]any `json:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + // UnmarshalStruct converts a structpb.Struct to a Go struct of type T func UnmarshalStruct[T any](s *structpb.Struct, out *T) error { structBytes, err := protojson.Marshal(s) diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go new file mode 100644 index 00000000000..5026950b173 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/project/config_test.go @@ -0,0 +1,304 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package project + +import ( + "testing" +) + +// TestServiceTargetAgentConfig_WithToolboxes tests MarshalStruct/UnmarshalStruct round-trip +// with the Toolboxes field populated. +func TestServiceTargetAgentConfig_WithToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{ + { + Name: "echo-toolbox", + Description: "A sample toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "github", + "server_url": "https://api.example.com/mcp", + "project_connection_id": "TestKey", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + + tb := roundTripped.Toolboxes[0] + if tb.Name != "echo-toolbox" { + t.Errorf("Expected toolbox name 'echo-toolbox', got '%s'", tb.Name) + } + if tb.Description != "A sample toolbox" { + t.Errorf("Expected description 'A sample toolbox', got '%s'", tb.Description) + } + if len(tb.Tools) != 1 { + t.Fatalf("Expected 1 tool in toolbox, got %d", len(tb.Tools)) + } + + tool := tb.Tools[0] + if tool["type"] != "mcp" { + t.Errorf("Expected tool type 'mcp', got '%v'", tool["type"]) + } + if tool["server_label"] != "github" { + t.Errorf("Expected server_label 'github', got '%v'", tool["server_label"]) + } +} + +// TestServiceTargetAgentConfig_EmptyToolboxes tests that an empty Toolboxes slice round-trips correctly. +func TestServiceTargetAgentConfig_EmptyToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{}, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 0 { + t.Errorf("Expected 0 toolboxes, got %d", len(roundTripped.Toolboxes)) + } +} + +// TestServiceTargetAgentConfig_MultipleToolboxes tests round-tripping multiple toolboxes. +func TestServiceTargetAgentConfig_MultipleToolboxes(t *testing.T) { + original := ServiceTargetAgentConfig{ + Toolboxes: []Toolbox{ + { + Name: "toolbox-one", + Description: "First toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "server-a", + "project_connection_id": "KeyA", + }, + }, + }, + { + Name: "toolbox-two", + Description: "Second toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "server-b", + "project_connection_id": "KeyB", + }, + { + "type": "mcp", + "server_label": "server-c", + "project_connection_id": "KeyC", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.Toolboxes) != 2 { + t.Fatalf("Expected 2 toolboxes, got %d", len(roundTripped.Toolboxes)) + } + + if roundTripped.Toolboxes[0].Name != "toolbox-one" { + t.Errorf("Expected first toolbox name 'toolbox-one', got '%s'", roundTripped.Toolboxes[0].Name) + } + + if roundTripped.Toolboxes[1].Name != "toolbox-two" { + t.Errorf("Expected second toolbox name 'toolbox-two', got '%s'", roundTripped.Toolboxes[1].Name) + } + + if len(roundTripped.Toolboxes[1].Tools) != 2 { + t.Errorf("Expected 2 tools in second toolbox, got %d", len(roundTripped.Toolboxes[1].Tools)) + } +} + +// TestServiceTargetAgentConfig_WithOtherFields tests that Toolboxes coexists correctly +// alongside other ServiceTargetAgentConfig fields. +func TestServiceTargetAgentConfig_WithOtherFields(t *testing.T) { + original := ServiceTargetAgentConfig{ + Environment: map[string]string{"KEY": "VALUE"}, + Deployments: []Deployment{ + { + Name: "test-deployment", + Model: DeploymentModel{ + Name: "gpt-4.1-mini", + Format: "OpenAI", + Version: "2025-04-14", + }, + Sku: DeploymentSku{ + Name: "Standard", + Capacity: 10, + }, + }, + }, + Toolboxes: []Toolbox{ + { + Name: "my-toolbox", + Description: "Coexisting toolbox", + Tools: []map[string]any{ + { + "type": "mcp", + "server_label": "test-server", + "project_connection_id": "TestConn", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if roundTripped.Environment["KEY"] != "VALUE" { + t.Errorf("Expected env KEY=VALUE, got '%s'", roundTripped.Environment["KEY"]) + } + + if len(roundTripped.Deployments) != 1 { + t.Fatalf("Expected 1 deployment, got %d", len(roundTripped.Deployments)) + } + + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + + if roundTripped.Toolboxes[0].Name != "my-toolbox" { + t.Errorf("Expected toolbox name 'my-toolbox', got '%s'", roundTripped.Toolboxes[0].Name) + } +} + +// TestServiceTargetAgentConfig_WithToolConnections tests MarshalStruct/UnmarshalStruct +// round-trip with the ToolConnections field populated. +func TestServiceTargetAgentConfig_WithToolConnections(t *testing.T) { + original := ServiceTargetAgentConfig{ + ToolConnections: []ToolConnection{ + { + Name: "github-mcp", + Category: "RemoteTool", + Target: "https://api.githubcopilot.com/mcp", + AuthType: "OAuth2", + Credentials: map[string]any{ //nolint:gosec // test data, not real credentials + "clientId": "${GITHUB_CLIENT_ID}", + "clientSecret": "${GITHUB_CLIENT_SECRET}", + }, + Metadata: map[string]string{ + "ApiType": "Azure", + }, + }, + }, + Toolboxes: []Toolbox{ + { + Name: "platform-tools", + Tools: []map[string]any{ + { + "type": "mcp", + "project_connection_id": "github-mcp", + "server_url": "https://api.githubcopilot.com/mcp", + }, + }, + }, + }, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.ToolConnections) != 1 { + t.Fatalf("Expected 1 tool connection, got %d", len(roundTripped.ToolConnections)) + } + + conn := roundTripped.ToolConnections[0] + if conn.Name != "github-mcp" { + t.Errorf("Expected connection name 'github-mcp', got '%s'", conn.Name) + } + if conn.Category != "RemoteTool" { + t.Errorf("Expected category 'RemoteTool', got '%s'", conn.Category) + } + if conn.Target != "https://api.githubcopilot.com/mcp" { + t.Errorf("Expected target 'https://api.githubcopilot.com/mcp', got '%s'", conn.Target) + } + if conn.AuthType != "OAuth2" { + t.Errorf("Expected authType 'OAuth2', got '%s'", conn.AuthType) + } + if conn.Credentials["clientId"] != "${GITHUB_CLIENT_ID}" { + t.Errorf("Expected clientId '${GITHUB_CLIENT_ID}', got '%v'", conn.Credentials["clientId"]) + } + if conn.Metadata["ApiType"] != "Azure" { + t.Errorf("Expected metadata ApiType 'Azure', got '%s'", conn.Metadata["ApiType"]) + } + + // Verify toolbox is also preserved + if len(roundTripped.Toolboxes) != 1 { + t.Fatalf("Expected 1 toolbox, got %d", len(roundTripped.Toolboxes)) + } + if roundTripped.Toolboxes[0].Tools[0]["project_connection_id"] != "github-mcp" { + t.Errorf("Expected project_connection_id 'github-mcp', got '%v'", + roundTripped.Toolboxes[0].Tools[0]["project_connection_id"]) + } +} + +// TestServiceTargetAgentConfig_EmptyToolConnections tests that an empty ToolConnections +// slice is omitted and doesn't break round-trip. +func TestServiceTargetAgentConfig_EmptyToolConnections(t *testing.T) { + original := ServiceTargetAgentConfig{ + ToolConnections: []ToolConnection{}, + } + + s, err := MarshalStruct(&original) + if err != nil { + t.Fatalf("MarshalStruct failed: %v", err) + } + + var roundTripped ServiceTargetAgentConfig + if err := UnmarshalStruct(s, &roundTripped); err != nil { + t.Fatalf("UnmarshalStruct failed: %v", err) + } + + if len(roundTripped.ToolConnections) != 0 { + t.Errorf("Expected 0 tool connections, got %d", len(roundTripped.ToolConnections)) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go index f23f15d7348..6c11e7ce59e 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/project/service_target_agent_test.go @@ -108,3 +108,30 @@ func TestApplyVnextMetadata(t *testing.T) { }) } } + +func TestGetServiceKey_NormalizesToolboxNames(t *testing.T) { + t.Parallel() + + p := &AgentServiceTargetProvider{} + + tests := []struct { + name string + input string + expected string + }{ + {"hyphens", "agent-tools", "AGENT_TOOLS"}, + {"spaces", "agent tools", "AGENT_TOOLS"}, + {"mixed", "my-agent tools", "MY_AGENT_TOOLS"}, + {"already upper", "TOOLS", "TOOLS"}, + {"lowercase", "tools", "TOOLS"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.getServiceKey(tt.input) + if got != tt.expected { + t.Errorf("getServiceKey(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json index 22d77ce8341..bf8b113ecda 100644 --- a/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json +++ b/cli/azd/extensions/azure.ai.agents/schemas/azure.ai.agent.json @@ -24,6 +24,21 @@ "description": "List of external resources for agent execution.", "items": { "$ref": "#/definitions/Resource" } }, + "toolConnections": { + "type": "array", + "description": "List of tool connections to external services (MCP tools, A2A, custom APIs) created during provisioning.", + "items": { "$ref": "#/definitions/ToolConnection" } + }, + "toolboxes": { + "type": "array", + "description": "List of toolboxes (Foundry Toolsets) to deploy.", + "items": { "$ref": "#/definitions/Toolbox" } + }, + "connections": { + "type": "array", + "description": "List of project connections to create via Bicep provisioning.", + "items": { "$ref": "#/definitions/Connection" } + }, "startupCommand": { "type": "string", "description": "Command to start the agent server (e.g., 'python main.py'). Used by 'azd ai agent run' for local development." @@ -119,6 +134,82 @@ }, "required": ["resource", "connectionName"], "additionalProperties": false + }, + "ToolConnection": { + "type": "object", + "description": "A connection to an external service (MCP tool, A2A, custom API) created via Bicep during provisioning.", + "properties": { + "name": { "type": "string", "description": "Connection name used as project_connection_id in toolbox tools." }, + "category": { "type": "string", "description": "Connection category (e.g., 'RemoteTool')." }, + "target": { "type": "string", "description": "Target endpoint URL for the connection." }, + "authType": { + "type": "string", + "description": "Authentication type for the connection.", + "enum": ["AAD", "AccessKey", "AccountKey", "ApiKey", "CustomKeys", "ManagedIdentity", "None", "OAuth2", "PAT", "ServicePrincipal", "UsernamePassword"] + }, + "credentials": { + "type": "object", + "description": "Credentials for the connection. Values may contain ${ENV_VAR} references resolved at provision time." + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the connection.", + "additionalProperties": { "type": "string" } + } + }, + "required": ["name", "category", "target", "authType"], + "additionalProperties": false + }, + "Toolbox": { + "type": "object", + "description": "A reusable collection of tools deployed as a Foundry Toolset.", + "properties": { + "name": { "type": "string", "description": "Name of the toolbox." }, + "description": { "type": "string", "description": "Description of the toolbox." }, + "tools": { + "type": "array", + "description": "List of tools in the toolbox. Each tool is an object with properties passed to the Foundry Toolsets API.", + "items": { "type": "object" } + } + }, + "required": ["name", "tools"], + "additionalProperties": false + }, + "Connection": { + "type": "object", + "description": "A project connection matching the Bicep ConnectionPropertiesV2 spec.", + "properties": { + "name": { "type": "string", "description": "Connection name.", "pattern": "^[a-zA-Z0-9][a-zA-Z0-9_-]{2,32}$" }, + "category": { "type": "string", "description": "Connection category (e.g., 'CustomKeys', 'AzureOpenAI', 'CognitiveSearch')." }, + "target": { "type": "string", "description": "Target endpoint URL for the connection." }, + "authType": { + "type": "string", + "description": "Authentication type.", + "enum": ["AAD", "ApiKey", "CustomKeys", "None", "OAuth2", "PAT"] + }, + "credentials": { + "type": "object", + "description": "Authentication credentials. Structure depends on authType." + }, + "metadata": { + "type": "object", + "description": "Additional metadata as key-value pairs.", + "additionalProperties": { "type": "string" } + }, + "expiryTime": { "type": "string", "description": "Connection expiry time." }, + "isSharedToAll": { "type": "boolean", "description": "Whether the connection is shared to all users." }, + "sharedUserList": { + "type": "array", + "description": "List of users the connection is shared with.", + "items": { "type": "string" } + }, + "peRequirement": { "type": "string", "description": "Private endpoint requirement." }, + "peStatus": { "type": "string", "description": "Private endpoint status." }, + "useWorkspaceManagedIdentity": { "type": "boolean", "description": "Whether to use workspace managed identity." }, + "error": { "type": "string", "description": "Error information." } + }, + "required": ["name", "category", "target", "authType"], + "additionalProperties": false } } } \ No newline at end of file diff --git a/cli/azd/pkg/azapi/azure_resource_types.go b/cli/azd/pkg/azapi/azure_resource_types.go index 9354854c0d0..57869e4ffab 100644 --- a/cli/azd/pkg/azapi/azure_resource_types.go +++ b/cli/azd/pkg/azapi/azure_resource_types.go @@ -57,6 +57,8 @@ const ( //nolint:lll AzureResourceTypeCognitiveServiceAccountProject AzureResourceType = "Microsoft.CognitiveServices/accounts/projects" //nolint:lll + AzureResourceTypeCognitiveServiceAccountProjectConnection AzureResourceType = "Microsoft.CognitiveServices/accounts/projects/connections" + //nolint:lll AzureResourceTypeCognitiveServiceAccountCapabilityHost AzureResourceType = "Microsoft.CognitiveServices/accounts/capabilityHosts" ) @@ -135,6 +137,8 @@ func GetResourceTypeDisplayName(resourceType AzureResourceType) string { return "Azure AI Services Model Deployment" case AzureResourceTypeCognitiveServiceAccountProject: return "Foundry project" + case AzureResourceTypeCognitiveServiceAccountProjectConnection: + return "Foundry project connection" case AzureResourceTypeCognitiveServiceAccountCapabilityHost: return "Foundry capability host" case AzureResourceTypeSearchService: